From bef8aa99245e056c2f6f9c078bf68feaec0f3462 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Tue, 5 May 2026 17:21:08 +0200 Subject: [PATCH 01/10] feat(router): add Code Mode MCP server Adds a second MCP server next to the existing per-operation one, exposing two generic tools instead of one-tool-per-operation: - code_mode_search_tools: takes prompts, generates GraphQL ops, registers them in the session catalog, returns their TS signatures. - code_mode_run_js: runs an async arrow against the session catalog inside a JS sandbox (V8 isolate), with tools.(vars) bound to the registered ops. Includes: - router/internal/codemode: harness, sandbox, server, storage, tsgen, yoko (query-gen client), calltrace, observability. - proto/wg/cosmo/code_mode/yoko/v1: Yoko query-gen Connect API. - router/pkg/config: code_mode config block (auth, sandbox limits, query-generation endpoint, named-ops, mutation approval). - demo/code-mode: local federation + Yoko mock + start scripts + MCP client config snippets (Claude Code, Claude Desktop, Codex). - demo/code-mode-connect: alternate demo against an external yoko Connect supergraph (set YOKO_DIR). - router-tests: end-to-end named-ops integration test. 
--- .gitignore | 4 + Makefile | 41 +- demo/code-mode-connect/README.md | 56 ++ demo/code-mode-connect/router-config.yaml | 82 ++ demo/code-mode-connect/start.sh | 146 ++++ demo/code-mode/.gitignore | 1 + demo/code-mode/Makefile | 30 + demo/code-mode/README.md | 60 ++ demo/code-mode/graph.yaml | 18 + demo/code-mode/mcp-configs/README.md | 122 +++ .../code-mode/mcp-configs/claude.desktop.json | 8 + demo/code-mode/mcp-configs/claude.mcp.json | 8 + demo/code-mode/mcp-configs/codex.toml | 2 + .../sample-prompts/01-search-employees.txt | 1 + .../sample-prompts/02-execute-fetch.txt | 1 + .../sample-prompts/03-multi-tool.txt | 1 + .../04-mutation-not-approved.txt | 1 + demo/code-mode/mcp-stdio-proxy/go.mod | 20 + demo/code-mode/mcp-stdio-proxy/go.sum | 30 + demo/code-mode/mcp-stdio-proxy/main.go | 162 ++++ demo/code-mode/mcp-stdio-proxy/main_test.go | 298 +++++++ demo/code-mode/router-config.yaml | 56 ++ demo/code-mode/run_subgraphs_subset.sh | 13 + demo/code-mode/start.sh | 149 ++++ demo/code-mode/yoko-mock/.gitignore | 3 + demo/code-mode/yoko-mock/README.md | 46 + demo/code-mode/yoko-mock/go.mod | 22 + demo/code-mode/yoko-mock/go.sum | 26 + demo/code-mode/yoko-mock/main.go | 583 +++++++++++++ demo/code-mode/yoko-mock/main_test.go | 169 ++++ demo/code-mode/yoko-mock/schema.graphql | 825 ++++++++++++++++++ .../employees/subgraph/schema.resolvers.go | 22 +- proto/wg/cosmo/code_mode/yoko/v1/yoko.proto | 84 ++ router-tests/code_mode_named_ops_test.go | 621 +++++++++++++ router-tests/go.mod | 7 + router-tests/go.sum | 10 + router-tests/testenv/testenv.go | 17 +- router/core/graph_server.go | 9 + router/core/router.go | 82 ++ router/core/router_config.go | 2 + .../wg/cosmo/code_mode/yoko/v1/yoko.pb.go | 451 ++++++++++ .../yoko/v1/yokov1connect/yoko.connect.go | 142 +++ router/go.mod | 5 + router/go.sum | 10 + .../internal/codemode/calltrace/calltrace.go | 92 ++ .../codemode/calltrace/calltrace_test.go | 51 ++ router/internal/codemode/deps.go | 8 + 
router/internal/codemode/harness/envelope.go | 203 +++++ .../codemode/harness/envelope_test.go | 61 ++ router/internal/codemode/harness/pipeline.go | 127 +++ .../codemode/harness/pipeline_test.go | 144 +++ router/internal/codemode/harness/shape.go | 97 ++ .../internal/codemode/harness/shape_test.go | 73 ++ router/internal/codemode/harness/transpile.go | 73 ++ .../codemode/harness/transpile_test.go | 61 ++ .../codemode/observability/logging.go | 48 + .../codemode/observability/logging_test.go | 43 + .../codemode/observability/metrics.go | 56 ++ .../codemode/observability/metrics_test.go | 76 ++ .../codemode/observability/tracing.go | 36 + .../codemode/observability/tracing_test.go | 67 ++ router/internal/codemode/sandbox/errors.go | 201 +++++ router/internal/codemode/sandbox/execute.go | 210 +++++ router/internal/codemode/sandbox/headers.go | 44 + router/internal/codemode/sandbox/host.go | 242 +++++ router/internal/codemode/sandbox/preamble.go | 28 + .../codemode/sandbox/preamble_test.go | 94 ++ router/internal/codemode/sandbox/sandbox.go | 168 ++++ .../codemode/sandbox/sandbox_preamble.js | 75 ++ .../internal/codemode/sandbox/sandbox_test.go | 648 ++++++++++++++ router/internal/codemode/sandbox/semaphore.go | 16 + .../internal/codemode/sandbox/validation.go | 98 +++ router/internal/codemode/server/approval.go | 195 +++++ .../internal/codemode/server/approval_test.go | 150 ++++ .../codemode/server/execute_handler.go | 101 +++ .../codemode/server/execute_handler_test.go | 431 +++++++++ router/internal/codemode/server/lifecycle.go | 182 ++++ .../codemode/server/lifecycle_test.go | 206 +++++ .../server/observability_handler_test.go | 180 ++++ .../codemode/server/search_handler.go | 264 ++++++ .../codemode/server/search_handler_test.go | 663 ++++++++++++++ router/internal/codemode/server/server.go | 458 ++++++++++ .../internal/codemode/server/server_test.go | 481 ++++++++++ router/internal/codemode/server/session.go | 34 + .../codemode/storage/memory_backend.go | 354 
++++++++ .../codemode/storage/memory_backend_test.go | 332 +++++++ router/internal/codemode/storage/naming.go | 191 ++++ .../internal/codemode/storage/naming_test.go | 84 ++ .../codemode/storage/redis_backend.go | 362 ++++++++ .../codemode/storage/redis_backend_test.go | 264 ++++++ router/internal/codemode/storage/storage.go | 29 + router/internal/codemode/storage/types.go | 15 + router/internal/codemode/tsgen/bundle_test.go | 138 +++ router/internal/codemode/tsgen/graphql.go | 674 ++++++++++++++ router/internal/codemode/tsgen/tsgen.go | 117 +++ router/internal/codemode/tsgen/tsgen_test.go | 411 +++++++++ router/internal/codemode/tsgen/typescript.go | 102 +++ router/internal/codemode/yoko/client.go | 158 ++++ router/internal/codemode/yoko/client_test.go | 434 +++++++++ router/internal/codemode/yoko/searcher.go | 15 + router/pkg/config/code_mode_config_test.go | 278 ++++++ router/pkg/config/code_mode_validation.go | 23 + router/pkg/config/config.go | 73 +- router/pkg/config/config.schema.json | 144 +++ router/pkg/config/fixtures/full.yaml | 32 + .../pkg/config/testdata/config_defaults.json | 37 + router/pkg/config/testdata/config_full.json | 39 +- 107 files changed, 15205 insertions(+), 32 deletions(-) create mode 100644 demo/code-mode-connect/README.md create mode 100644 demo/code-mode-connect/router-config.yaml create mode 100755 demo/code-mode-connect/start.sh create mode 100644 demo/code-mode/.gitignore create mode 100644 demo/code-mode/Makefile create mode 100644 demo/code-mode/README.md create mode 100644 demo/code-mode/graph.yaml create mode 100644 demo/code-mode/mcp-configs/README.md create mode 100644 demo/code-mode/mcp-configs/claude.desktop.json create mode 100644 demo/code-mode/mcp-configs/claude.mcp.json create mode 100644 demo/code-mode/mcp-configs/codex.toml create mode 100644 demo/code-mode/mcp-configs/sample-prompts/01-search-employees.txt create mode 100644 demo/code-mode/mcp-configs/sample-prompts/02-execute-fetch.txt create mode 100644 
demo/code-mode/mcp-configs/sample-prompts/03-multi-tool.txt create mode 100644 demo/code-mode/mcp-configs/sample-prompts/04-mutation-not-approved.txt create mode 100644 demo/code-mode/mcp-stdio-proxy/go.mod create mode 100644 demo/code-mode/mcp-stdio-proxy/go.sum create mode 100644 demo/code-mode/mcp-stdio-proxy/main.go create mode 100644 demo/code-mode/mcp-stdio-proxy/main_test.go create mode 100644 demo/code-mode/router-config.yaml create mode 100755 demo/code-mode/run_subgraphs_subset.sh create mode 100755 demo/code-mode/start.sh create mode 100644 demo/code-mode/yoko-mock/.gitignore create mode 100644 demo/code-mode/yoko-mock/README.md create mode 100644 demo/code-mode/yoko-mock/go.mod create mode 100644 demo/code-mode/yoko-mock/go.sum create mode 100644 demo/code-mode/yoko-mock/main.go create mode 100644 demo/code-mode/yoko-mock/main_test.go create mode 100644 demo/code-mode/yoko-mock/schema.graphql create mode 100644 proto/wg/cosmo/code_mode/yoko/v1/yoko.proto create mode 100644 router-tests/code_mode_named_ops_test.go create mode 100644 router/gen/proto/wg/cosmo/code_mode/yoko/v1/yoko.pb.go create mode 100644 router/gen/proto/wg/cosmo/code_mode/yoko/v1/yokov1connect/yoko.connect.go create mode 100644 router/internal/codemode/calltrace/calltrace.go create mode 100644 router/internal/codemode/calltrace/calltrace_test.go create mode 100644 router/internal/codemode/deps.go create mode 100644 router/internal/codemode/harness/envelope.go create mode 100644 router/internal/codemode/harness/envelope_test.go create mode 100644 router/internal/codemode/harness/pipeline.go create mode 100644 router/internal/codemode/harness/pipeline_test.go create mode 100644 router/internal/codemode/harness/shape.go create mode 100644 router/internal/codemode/harness/shape_test.go create mode 100644 router/internal/codemode/harness/transpile.go create mode 100644 router/internal/codemode/harness/transpile_test.go create mode 100644 router/internal/codemode/observability/logging.go 
create mode 100644 router/internal/codemode/observability/logging_test.go create mode 100644 router/internal/codemode/observability/metrics.go create mode 100644 router/internal/codemode/observability/metrics_test.go create mode 100644 router/internal/codemode/observability/tracing.go create mode 100644 router/internal/codemode/observability/tracing_test.go create mode 100644 router/internal/codemode/sandbox/errors.go create mode 100644 router/internal/codemode/sandbox/execute.go create mode 100644 router/internal/codemode/sandbox/headers.go create mode 100644 router/internal/codemode/sandbox/host.go create mode 100644 router/internal/codemode/sandbox/preamble.go create mode 100644 router/internal/codemode/sandbox/preamble_test.go create mode 100644 router/internal/codemode/sandbox/sandbox.go create mode 100644 router/internal/codemode/sandbox/sandbox_preamble.js create mode 100644 router/internal/codemode/sandbox/sandbox_test.go create mode 100644 router/internal/codemode/sandbox/semaphore.go create mode 100644 router/internal/codemode/sandbox/validation.go create mode 100644 router/internal/codemode/server/approval.go create mode 100644 router/internal/codemode/server/approval_test.go create mode 100644 router/internal/codemode/server/execute_handler.go create mode 100644 router/internal/codemode/server/execute_handler_test.go create mode 100644 router/internal/codemode/server/lifecycle.go create mode 100644 router/internal/codemode/server/lifecycle_test.go create mode 100644 router/internal/codemode/server/observability_handler_test.go create mode 100644 router/internal/codemode/server/search_handler.go create mode 100644 router/internal/codemode/server/search_handler_test.go create mode 100644 router/internal/codemode/server/server.go create mode 100644 router/internal/codemode/server/server_test.go create mode 100644 router/internal/codemode/server/session.go create mode 100644 router/internal/codemode/storage/memory_backend.go create mode 100644 
router/internal/codemode/storage/memory_backend_test.go create mode 100644 router/internal/codemode/storage/naming.go create mode 100644 router/internal/codemode/storage/naming_test.go create mode 100644 router/internal/codemode/storage/redis_backend.go create mode 100644 router/internal/codemode/storage/redis_backend_test.go create mode 100644 router/internal/codemode/storage/storage.go create mode 100644 router/internal/codemode/storage/types.go create mode 100644 router/internal/codemode/tsgen/bundle_test.go create mode 100644 router/internal/codemode/tsgen/graphql.go create mode 100644 router/internal/codemode/tsgen/tsgen.go create mode 100644 router/internal/codemode/tsgen/tsgen_test.go create mode 100644 router/internal/codemode/tsgen/typescript.go create mode 100644 router/internal/codemode/yoko/client.go create mode 100644 router/internal/codemode/yoko/client_test.go create mode 100644 router/internal/codemode/yoko/searcher.go create mode 100644 router/pkg/config/code_mode_config_test.go create mode 100644 router/pkg/config/code_mode_validation.go diff --git a/.gitignore b/.gitignore index 4419047d6b..eeeeeb87bf 100644 --- a/.gitignore +++ b/.gitignore @@ -26,3 +26,7 @@ go.work.sum # Serena .serena/ + +# Local agent / planning artifacts (not for public commits) +.claude/ +docs/superpowers/ \ No newline at end of file diff --git a/Makefile b/Makefile index f7777d71f7..12ec52c48c 100644 --- a/Makefile +++ b/Makefile @@ -116,7 +116,7 @@ generate: make generate-go generate-go: - rm -rf router/gen && buf generate --path proto/wg/cosmo/node --path proto/wg/cosmo/common --path proto/wg/cosmo/graphqlmetrics --template buf.router.go.gen.yaml + rm -rf router/gen && buf generate --path proto/wg/cosmo/node --path proto/wg/cosmo/common --path proto/wg/cosmo/graphqlmetrics --path proto/wg/cosmo/code_mode/yoko/v1 --template buf.router.go.gen.yaml rm -rf graphqlmetrics/gen && buf generate --path proto/wg/cosmo/graphqlmetrics --path proto/wg/cosmo/common --template 
buf.graphqlmetrics.go.gen.yaml rm -rf connect-go/wg && buf generate --path proto/wg/cosmo/platform --path proto/wg/cosmo/notifications --path proto/wg/cosmo/common --path proto/wg/cosmo/node --template buf.connect-go.go.gen.yaml @@ -187,6 +187,45 @@ docker-build-minikube: docker-build-local run-subgraphs-local: cd demo && go run cmd/all/main.go +CODE_MODE_GOCACHE ?= /tmp/cosmo-code-mode-go-build-cache + +.PHONY: code-mode-demo code-mode-demo-down code-mode-connect-demo code-mode-connect-demo-down + +# Local Code Mode demo: small federation (employees, family, availability, +# mood) + Yoko mock + Cosmo Router with Code Mode and named operations. +# Router GraphQL on :3002, MCP on :5027. Full instructions, prerequisites +# (codex CLI on PATH), and tear-down: demo/code-mode/README.md. +code-mode-demo: + mkdir -p $(CODE_MODE_GOCACHE) + GOCACHE=$(CODE_MODE_GOCACHE) $(MAKE) -C router build + GOCACHE=$(CODE_MODE_GOCACHE) $(MAKE) -C demo/code-mode build-yoko + GOCACHE=$(CODE_MODE_GOCACHE) $(MAKE) -C demo/code-mode build-stdio-proxy + $(MAKE) -C demo/code-mode compose + ./demo/code-mode/start.sh + +# Tear down anything left behind by code-mode-demo. +code-mode-demo-down: + ./demo/code-mode/start.sh --down + +# Runs the code-mode router from source against the yoko Connect supergraph +# (plugins + composed config live in $(YOKO_DIR)). Uses different ports than +# code-mode-demo (router 3012, MCP 5037, yoko-mock 5038) so both can run at +# the same time. Set YOKO_DIR to your local yoko checkout, e.g. +# `make code-mode-connect-demo YOKO_DIR=/path/to/yoko`. +# Full instructions and prerequisites: demo/code-mode-connect/README.md. +YOKO_DIR ?= + +code-mode-connect-demo: + @if [ -z "$(YOKO_DIR)" ]; then echo "YOKO_DIR is required (path to your yoko checkout). 
See demo/code-mode-connect/README.md" >&2; exit 1; fi + mkdir -p $(CODE_MODE_GOCACHE) + GOCACHE=$(CODE_MODE_GOCACHE) $(MAKE) -C router build + GOCACHE=$(CODE_MODE_GOCACHE) $(MAKE) -C demo/code-mode build-yoko + YOKO_DIR=$(YOKO_DIR) ./demo/code-mode-connect/start.sh + +# Tear down anything left behind by code-mode-connect-demo. +code-mode-connect-demo-down: + ./demo/code-mode-connect/start.sh --down + sync-go-workspace: cd router && go mod tidy cd demo && make bump-deps diff --git a/demo/code-mode-connect/README.md b/demo/code-mode-connect/README.md new file mode 100644 index 0000000000..cdc494e5ec --- /dev/null +++ b/demo/code-mode-connect/README.md @@ -0,0 +1,56 @@ +# Code Mode Connect Demo + +This demo runs the Code Mode router against an external `yoko` Connect supergraph instead of the local employees federation used by `make code-mode-demo`. +It is useful when you want to exercise Code Mode against a richer set of plugins (Pylon, Linear, PostHog, Circleback, Slack, Notion) served by the `yoko` project. + +It is designed to coexist with `make code-mode-demo`: it uses different ports (router 3012, MCP 5037, yoko-mock 5038), so both demos can run side-by-side. + +## Prerequisites + +- A local checkout of the `yoko` Connect supergraph project (separate repository). + Inside that checkout you must already have built the plugins and composed the supergraph so that the directory contains: + - `config.json` — the composed router config for the yoko supergraph. + - `plugins/` — the plugin binaries the router will load. +- Go (toolchain matching the repo `go.mod`). +- The `codex` CLI on `PATH`, authenticated. The Yoko mock shells out to `codex` for query generation. + +## Run + +From the repository root, set `YOKO_DIR` to your local yoko checkout and run: + +```sh +make code-mode-connect-demo YOKO_DIR=/path/to/yoko +``` + +`YOKO_DIR` is required. +The target fails fast with a clear error if it is missing or if the directory does not contain `config.json`. 
+ +What the target does: + +1. Builds `router/router`. +2. Builds `demo/code-mode/yoko-mock/yoko-mock`. +3. Starts `yoko-mock` on `localhost:5038`. +4. Starts the router with `YOKO_DIR` as its working directory and `demo/code-mode-connect/router-config.yaml` as its config. + The router resolves `config.json` and `plugins/` relative to that CWD, which is why `YOKO_DIR` must be a real composed yoko checkout. + +Expected ports: + +- Router GraphQL: `http://localhost:3012/graphql` +- Code Mode MCP: `http://127.0.0.1:5037/mcp` +- Yoko mock: `http://localhost:5038` + +## Tearing down + +Press Ctrl-C in the foreground terminal. +If anything is left behind, run: + +```sh +make code-mode-connect-demo-down +``` + +The process logs for background services are written to `/tmp/cosmo-code-mode-connect-demo-logs`. + +## Auth headers + +`router-config.yaml` propagates the auth headers expected by the yoko plugins (`X-Pylon-Token`, `X-Linear-Token`, `X-Posthog-Token`, `X-Circleback-Token`, `X-Slack-Token`, `X-Notion-Token`, etc.). +Provide values for these on the request side when calling the router so the plugins can reach their upstream services. diff --git a/demo/code-mode-connect/router-config.yaml b/demo/code-mode-connect/router-config.yaml new file mode 100644 index 0000000000..b2e102fe00 --- /dev/null +++ b/demo/code-mode-connect/router-config.yaml @@ -0,0 +1,82 @@ +version: "1" + +# Different ports than demo/code-mode/router-config.yaml so both demos can run +# side-by-side. See demo/code-mode-connect/start.sh for the matching yoko-mock +# port. +listen_addr: "localhost:3012" +graphql_path: "/graphql" +playground_enabled: false +json_log: false +log_level: info +dev_mode: true +router_registration: false + +# These paths are resolved relative to the router's CWD. start.sh runs the +# router from inside the yoko project dir, so "config.json" and "plugins" are +# the composed supergraph and the plugin binaries that ship with that repo. 
+execution_config: + file: + path: "config.json" + watch: false + +plugins: + enabled: true + path: "plugins" + +# Header propagation for the yoko plugins. Mirrors yoko/config.yaml so the +# plugins receive the same auth headers when the code-mode router fronts them. +headers: + all: + request: + - op: propagate + named: X-Pylon-Token + - op: propagate + named: X-Linear-Token + - op: propagate + named: X-Linear-Auth-Scheme + - op: propagate + named: X-Posthog-Token + - op: propagate + named: X-Posthog-Host + - op: propagate + named: X-Posthog-Project-Id + - op: propagate + named: X-Circleback-Token + - op: propagate + named: X-Slack-Token + - op: propagate + named: X-Notion-Token + +graphql_metrics: + enabled: false + +telemetry: + tracing: + enabled: false + metrics: + otlp: + enabled: false + prometheus: + enabled: false + +mcp: + enabled: false + graph_name: code-mode-connect-demo + router_url: http://localhost:3012/graphql + session: + stateless: false + code_mode: + enabled: true + server: + # IPv4-only bind, see demo/code-mode/router-config.yaml for the why. + listen_addr: 127.0.0.1:5037 + require_mutation_approval: true + sandbox: + timeout: 180s + query_generation: + enabled: true + endpoint: http://localhost:5038 + timeout: 180s + execute_timeout: 180s + named_ops: + enabled: true diff --git a/demo/code-mode-connect/start.sh b/demo/code-mode-connect/start.sh new file mode 100755 index 0000000000..7379fa85d3 --- /dev/null +++ b/demo/code-mode-connect/start.sh @@ -0,0 +1,146 @@ +#!/usr/bin/env bash + +set -Eeuo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)" +DEMO_DIR="$ROOT_DIR/demo" +CONNECT_DIR="$DEMO_DIR/code-mode-connect" +PID_FILE="/tmp/cosmo-code-mode-connect-demo.pids" +LOG_DIR="/tmp/cosmo-code-mode-connect-demo-logs" +GOCACHE_DIR="${GOCACHE:-/tmp/cosmo-code-mode-go-build-cache}" + +# Yoko project that owns the supergraph + plugin binaries. 
Required: +# YOKO_DIR=/path/to/yoko ./start.sh +YOKO_DIR="${YOKO_DIR:?YOKO_DIR is required (path to your yoko checkout)}" + +ROUTER_BIN="$ROOT_DIR/router/router" +ROUTER_CONFIG="$CONNECT_DIR/router-config.yaml" +YOKO_BIN="$DEMO_DIR/code-mode/yoko-mock/yoko-mock" + +append_pid() { + local name="$1" + local pid="$2" + printf '%s %s\n' "$name" "$pid" >> "$PID_FILE" +} + +kill_pid_file() { + if [ ! -f "$PID_FILE" ]; then + echo "No code-mode-connect demo PID file found at $PID_FILE" + return 0 + fi + + while read -r name pid; do + [ -n "${pid:-}" ] || continue + if kill -0 "$pid" 2>/dev/null; then + echo "Stopping $name pid=$pid" + kill "$pid" 2>/dev/null || true + fi + done < "$PID_FILE" + + sleep 1 + + while read -r name pid; do + [ -n "${pid:-}" ] || continue + if kill -0 "$pid" 2>/dev/null; then + echo "Force stopping $name pid=$pid" + kill -KILL "$pid" 2>/dev/null || true + fi + done < "$PID_FILE" + + rm -f "$PID_FILE" +} + +cleanup() { + local status=$? + trap - EXIT INT TERM + kill_pid_file + exit "$status" +} + +wait_url() { + local name="$1" + local url="$2" + local timeout="${3:-90}" + local start + start="$(date +%s)" + + while true; do + if curl -fsS "$url" >/dev/null 2>&1; then + echo "$name is ready at $url" + return 0 + fi + + if [ "$(( $(date +%s) - start ))" -ge "$timeout" ]; then + echo "Timed out waiting for $name at $url" >&2 + echo "Logs are in $LOG_DIR" >&2 + return 1 + fi + + sleep 1 + done +} + +start_background_root() { + local name="$1" + shift + + echo "Starting $name" + # exec replaces the subshell with the binary, so $! is the binary's pid. + # Without exec, the subshell forks the binary and `--down` ends up signalling + # an already-exited subshell while the real process keeps running. + ( + cd "$ROOT_DIR" + exec "$@" + ) > "$LOG_DIR/$name.log" 2>&1 & + append_pid "$name" "$!" +} + +if [ "${1:-}" = "--down" ]; then + kill_pid_file + exit 0 +fi + +if [ ! 
-d "$YOKO_DIR" ]; then + echo "Yoko project directory not found: $YOKO_DIR" >&2 + echo "Set YOKO_DIR to override." >&2 + exit 1 +fi + +if [ ! -x "$ROUTER_BIN" ]; then + echo "Router binary not found or not executable: $ROUTER_BIN" >&2 + echo "Run: cd router && make build" >&2 + exit 1 +fi + +if [ ! -x "$YOKO_BIN" ]; then + echo "Yoko mock binary not found or not executable: $YOKO_BIN" >&2 + echo "Run: make -C demo/code-mode build-yoko" >&2 + exit 1 +fi + +if [ ! -f "$YOKO_DIR/config.json" ]; then + echo "Composed yoko supergraph not found: $YOKO_DIR/config.json" >&2 + echo "Run: cd $YOKO_DIR && make compose" >&2 + exit 1 +fi + +mkdir -p "$LOG_DIR" +mkdir -p "$GOCACHE_DIR" +rm -f "$PID_FILE" +trap cleanup EXIT INT TERM + +# yoko-mock listens on a different port than the regular code-mode-demo so the +# two demos can coexist (5028 vs 5038). +start_background_root yoko "$YOKO_BIN" -listen-addr localhost:5038 + +wait_url yoko http://localhost:5038/health + +echo "Starting router in foreground (CWD=$YOKO_DIR)" +( + cd "$YOKO_DIR" + exec "$ROUTER_BIN" -config "$ROUTER_CONFIG" +) & +router_pid="$!" +append_pid router "$router_pid" + +wait "$router_pid" diff --git a/demo/code-mode/.gitignore b/demo/code-mode/.gitignore new file mode 100644 index 0000000000..bc5fd710be --- /dev/null +++ b/demo/code-mode/.gitignore @@ -0,0 +1 @@ +mcp-stdio-proxy/mcp-stdio-proxy diff --git a/demo/code-mode/Makefile b/demo/code-mode/Makefile new file mode 100644 index 0000000000..1114f6ea7c --- /dev/null +++ b/demo/code-mode/Makefile @@ -0,0 +1,30 @@ +SHELL := bash +GOCACHE ?= /tmp/cosmo-code-mode-go-build-cache +wgc_env_arg = $(if $(wildcard ../cli/.env),--env-file ../cli/.env,) +wgc_router = pnpm dlx tsx $(wgc_env_arg) ../cli/src/index.ts router + +.PHONY: build-yoko build-stdio-proxy compose start down run-subgraphs + +build-yoko: + mkdir -p $(GOCACHE) + cd yoko-mock && GOCACHE=$(GOCACHE) go build -o yoko-mock . 
+ +build-stdio-proxy: + mkdir -p $(GOCACHE) + cd mcp-stdio-proxy && GOCACHE=$(GOCACHE) go build -o mcp-stdio-proxy . + +compose: + cd .. && if [ -f ../cli/dist/src/index.js ]; then \ + DISABLE_UPDATE_CHECK=true node ../cli/dist/src/index.js router compose -i ./code-mode/graph.yaml -o ./code-mode/config.json; \ + else \ + DISABLE_UPDATE_CHECK=true TMPDIR=/tmp $(wgc_router) compose -i ./code-mode/graph.yaml -o ./code-mode/config.json; \ + fi + +start: + ./start.sh + +down: + ./start.sh --down + +run-subgraphs: + ./run_subgraphs_subset.sh diff --git a/demo/code-mode/README.md b/demo/code-mode/README.md new file mode 100644 index 0000000000..dee17d14a2 --- /dev/null +++ b/demo/code-mode/README.md @@ -0,0 +1,60 @@ +# Code Mode Demo + +This demo starts a small local federation (`employees`, `family`, `availability`, and `mood`), the Code Mode Yoko mock, and a local Cosmo Router with Code Mode and named operations enabled. + +## Prerequisites + +- Go (toolchain matching the repo `go.mod`). +- Node + `pnpm` (used by `wgc` to compose `demo/code-mode/graph.yaml`). +- The `codex` CLI on `PATH`, authenticated. + The Yoko mock shells out to `codex` for query generation; + without it, `code_mode_search_tools` cannot generate operations. + +## Quick start + +Run it from the repository root: + +```sh +make code-mode-demo +``` + +The root target builds `router/router`, builds `demo/code-mode/yoko-mock/yoko-mock`, builds `demo/code-mode/mcp-stdio-proxy/mcp-stdio-proxy` (used by stdio-only MCP clients like Claude Desktop), composes `demo/code-mode/graph.yaml` into `demo/code-mode/config.json`, then starts the demo processes. +The router stays in the foreground. 
+ +Expected ports: + +- Router GraphQL: `http://localhost:3002/graphql` +- Code Mode MCP: `http://localhost:5027/mcp` +- Yoko mock: `http://localhost:5028` +- Employees subgraph: `http://localhost:4001/graphql` +- Family subgraph: `http://localhost:4002/graphql` +- Availability subgraph: `http://localhost:4007/graphql` +- Mood subgraph: `http://localhost:4008/graphql` + +## Tearing down + +To stop the demo, press Ctrl-C in the foreground terminal. +If anything is left behind (background subgraphs, yoko-mock), run: + +```sh +make code-mode-demo-down +``` + +The process logs for background services are written to `/tmp/cosmo-code-mode-demo-logs`. + +## Manual smoke check + +```sh +make code-mode-demo +curl -sS http://localhost:3002/graphql \ + -H 'content-type: application/json' \ + --data '{"query":"{ employees { id details { forename surname } } }"}' +``` + +## Other notes + +The subset runner is `demo/code-mode/run_subgraphs_subset.sh`. It starts only `employees`, `family`, `availability`, and `mood` via `npx concurrently` for a fast demo. `availability` and `mood` are included because the `employees` schema has federation references to fields owned by those subgraphs. The full demo `demo/run_subgraphs.sh` starts all subgraphs and is intentionally not used here. + +Client configuration for Code Mode MCP clients (Claude Code, Claude Desktop, Codex CLI) lives under `demo/code-mode/mcp-configs/` — see the README there. + +For the alternate "Connect" variant of this demo, which runs the same Code Mode router against an external `yoko` Connect supergraph instead of the local employees federation, see `demo/code-mode-connect/README.md`. 
diff --git a/demo/code-mode/graph.yaml b/demo/code-mode/graph.yaml new file mode 100644 index 0000000000..e95412def2 --- /dev/null +++ b/demo/code-mode/graph.yaml @@ -0,0 +1,18 @@ +version: 1 +subgraphs: + - name: employees + routing_url: http://localhost:4001/graphql + schema: + file: ../pkg/subgraphs/employees/subgraph/schema.graphqls + - name: family + routing_url: http://localhost:4002/graphql + schema: + file: ../pkg/subgraphs/family/subgraph/schema.graphqls + - name: availability + routing_url: http://localhost:4007/graphql + schema: + file: ../pkg/subgraphs/availability/subgraph/schema.graphqls + - name: mood + routing_url: http://localhost:4008/graphql + schema: + file: ../pkg/subgraphs/mood/subgraph/schema.graphqls diff --git a/demo/code-mode/mcp-configs/README.md b/demo/code-mode/mcp-configs/README.md new file mode 100644 index 0000000000..8fa517f95c --- /dev/null +++ b/demo/code-mode/mcp-configs/README.md @@ -0,0 +1,122 @@ +# Code Mode MCP Client Configs + +These snippets connect MCP clients to the Code Mode demo server at `http://localhost:5027/mcp`. +Start the demo first: + +```bash +make code-mode-demo +``` + +The configs are illustrative. +Real users can adapt paths, server names, timeouts, and auth settings for their local setup. +Do not add API keys or auth tokens to these files. + +## Claude Code + +`claude.mcp.json` matches Claude Code's `mcpServers` settings schema for Streamable HTTP: + +```json +{ + "mcpServers": { + "yoko": { + "type": "http", + "url": "http://localhost:5027/mcp" + } + } +} +``` + +Run with the config snippet directly: + +```bash +claude --mcp-config demo/code-mode/mcp-configs/claude.mcp.json --strict-mcp-config -p "$(cat demo/code-mode/mcp-configs/sample-prompts/01-search-employees.txt)" +``` + +Or install it into Claude Code project config: + +```bash +claude mcp add --scope project --transport http yoko http://localhost:5027/mcp +``` + +Claude Code writes project-scoped MCP servers to `.mcp.json`. 
+Use `--scope user` instead if you want the server available outside this checkout. + +## Claude Desktop + +Claude Desktop only speaks stdio, so it cannot connect to the demo's HTTP MCP endpoint directly. +The demo ships a tiny `mcp-stdio-proxy` binary that bridges Claude Desktop's stdio transport to the upstream HTTP server. +`make code-mode-demo` builds it at `demo/code-mode/mcp-stdio-proxy/mcp-stdio-proxy`. + +`claude.desktop.json` is the matching config: + +```json +{ + "mcpServers": { + "yoko": { + "command": "/ABSOLUTE/PATH/TO/cosmo/demo/code-mode/mcp-stdio-proxy/mcp-stdio-proxy", + "args": ["--upstream", "http://127.0.0.1:5027/mcp"] + } + } +} +``` + +Replace `/ABSOLUTE/PATH/TO/cosmo` with the absolute path to your checkout, then merge into `~/Library/Application Support/Claude/claude_desktop_config.json` (macOS) or `%APPDATA%\Claude\claude_desktop_config.json` (Windows) and restart Claude Desktop. + +## Codex CLI + +`codex.toml` matches Codex CLI's `~/.codex/config.toml` table format: + +```toml +[mcp_servers."yoko"] +url = "http://localhost:5027/mcp" +``` + +Install it by copying the table into `~/.codex/config.toml`, or add the same server with: + +```bash +codex mcp add yoko --url http://localhost:5027/mcp +``` + +Then run a prompt with your normal Codex config: + +```bash +codex exec --full-auto -- "$(cat demo/code-mode/mcp-configs/sample-prompts/02-execute-fetch.txt)" +``` + +To point one invocation at this snippet without editing your global config, pass equivalent config overrides: + +```bash +codex exec --full-auto \ + -c 'mcp_servers.yoko.url="http://localhost:5027/mcp"' \ + -- "$(cat demo/code-mode/mcp-configs/sample-prompts/02-execute-fetch.txt)" +``` + +Codex does not currently expose a direct `--config-file` flag for `codex.toml`. 
+For an isolated run against the checked-in snippet, place it at `$CODEX_HOME/config.toml` in a temporary directory: + +```bash +tmp_codex_home="$(mktemp -d)" +cp demo/code-mode/mcp-configs/codex.toml "$tmp_codex_home/config.toml" +CODEX_HOME="$tmp_codex_home" codex exec --full-auto -- "$(cat demo/code-mode/mcp-configs/sample-prompts/02-execute-fetch.txt)" +``` + +## Sample Prompts + +`sample-prompts/01-search-employees.txt` asks the client to call `code_mode_search_tools` with two prompts in one batch. +Expected output shape: the assistant should show the newly returned TypeScript `tools` declarations for the first-employee operation and the employee-by-id operation. + +`sample-prompts/02-execute-fetch.txt` asks the client to discover an employee-by-id operation and run `code_mode_run_js`. +Expected output shape: the assistant should show a `code_mode_run_js` result for employee `1`, returning the employee's `forename` and `surname`. + +`sample-prompts/03-multi-tool.txt` asks the client to discover two operations and compose them in a single `code_mode_run_js` program. +Expected output shape: the assistant should return both the first employee and that employee's family from one sandbox execution. + +`sample-prompts/04-mutation-not-approved.txt` asks the client to try an employee-tag mutation. +The historical prompt name mentions "not approved", but the demo config sets `require_mutation_approval: false` in `demo/code-mode/router-config.yaml`. +That means this prompt is not declined by operator approval in the default demo; it should run like a normal mutation if the mock can generate the operation. +Skip this prompt when you specifically need to demonstrate approval rejection. + +## Caveat + +The mock Yoko service shells out to the `codex` CLI for query generation. +The local `codex` CLI must be installed and authenticated before `code_mode_search_tools` can generate operations. 
diff --git a/demo/code-mode/mcp-configs/claude.desktop.json b/demo/code-mode/mcp-configs/claude.desktop.json new file mode 100644 index 0000000000..6297c6062d --- /dev/null +++ b/demo/code-mode/mcp-configs/claude.desktop.json @@ -0,0 +1,8 @@ +{ + "mcpServers": { + "yoko": { + "command": "/ABSOLUTE/PATH/TO/cosmo/demo/code-mode/mcp-stdio-proxy/mcp-stdio-proxy", + "args": ["--upstream", "http://127.0.0.1:5027/mcp"] + } + } +} diff --git a/demo/code-mode/mcp-configs/claude.mcp.json b/demo/code-mode/mcp-configs/claude.mcp.json new file mode 100644 index 0000000000..f5dfa28e16 --- /dev/null +++ b/demo/code-mode/mcp-configs/claude.mcp.json @@ -0,0 +1,8 @@ +{ + "mcpServers": { + "yoko": { + "type": "http", + "url": "http://localhost:5027/mcp" + } + } +} diff --git a/demo/code-mode/mcp-configs/codex.toml b/demo/code-mode/mcp-configs/codex.toml new file mode 100644 index 0000000000..03f1390f70 --- /dev/null +++ b/demo/code-mode/mcp-configs/codex.toml @@ -0,0 +1,2 @@ +[mcp_servers."yoko"] +url = "http://localhost:5027/mcp" diff --git a/demo/code-mode/mcp-configs/sample-prompts/01-search-employees.txt b/demo/code-mode/mcp-configs/sample-prompts/01-search-employees.txt new file mode 100644 index 0000000000..8777d17f1e --- /dev/null +++ b/demo/code-mode/mcp-configs/sample-prompts/01-search-employees.txt @@ -0,0 +1 @@ +Use the yoko MCP server. Call code_mode_search_tools with prompts that fetch the first employee and an employee by id. Then show me the TS that came back. diff --git a/demo/code-mode/mcp-configs/sample-prompts/02-execute-fetch.txt b/demo/code-mode/mcp-configs/sample-prompts/02-execute-fetch.txt new file mode 100644 index 0000000000..5084163314 --- /dev/null +++ b/demo/code-mode/mcp-configs/sample-prompts/02-execute-fetch.txt @@ -0,0 +1 @@ +Use yoko. Search for an op that fetches an employee by id, then write a code_mode_run_js program that fetches employee 1 and returns the forename + surname. 
diff --git a/demo/code-mode/mcp-configs/sample-prompts/03-multi-tool.txt b/demo/code-mode/mcp-configs/sample-prompts/03-multi-tool.txt new file mode 100644 index 0000000000..f4939ebb49 --- /dev/null +++ b/demo/code-mode/mcp-configs/sample-prompts/03-multi-tool.txt @@ -0,0 +1 @@ +Use yoko. Discover ops to (a) get the first employee and (b) get the family of a specific employee id; then run a single code_mode_run_js program that fetches the first employee, then their family, and returns both. diff --git a/demo/code-mode/mcp-configs/sample-prompts/04-mutation-not-approved.txt b/demo/code-mode/mcp-configs/sample-prompts/04-mutation-not-approved.txt new file mode 100644 index 0000000000..4f2e1c86c6 --- /dev/null +++ b/demo/code-mode/mcp-configs/sample-prompts/04-mutation-not-approved.txt @@ -0,0 +1 @@ +Use yoko. Try to update an employee tag and see what happens. diff --git a/demo/code-mode/mcp-stdio-proxy/go.mod b/demo/code-mode/mcp-stdio-proxy/go.mod new file mode 100644 index 0000000000..4720b36f3c --- /dev/null +++ b/demo/code-mode/mcp-stdio-proxy/go.mod @@ -0,0 +1,20 @@ +module github.com/wundergraph/cosmo/demo/code-mode/mcp-stdio-proxy + +go 1.25 + +require ( + github.com/modelcontextprotocol/go-sdk v1.4.1 + github.com/stretchr/testify v1.11.1 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/google/jsonschema-go v0.4.2 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/segmentio/asm v1.1.3 // indirect + github.com/segmentio/encoding v0.5.4 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + golang.org/x/oauth2 v0.34.0 // indirect + golang.org/x/sys v0.40.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/demo/code-mode/mcp-stdio-proxy/go.sum b/demo/code-mode/mcp-stdio-proxy/go.sum new file mode 100644 index 0000000000..e469bb22cf --- /dev/null +++ b/demo/code-mode/mcp-stdio-proxy/go.sum @@ -0,0 +1,30 @@ +github.com/davecgh/go-spew v1.1.1 
h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= +github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/jsonschema-go v0.4.2 h1:tmrUohrwoLZZS/P3x7ex0WAVknEkBZM46iALbcqoRA8= +github.com/google/jsonschema-go v0.4.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= +github.com/modelcontextprotocol/go-sdk v1.4.1 h1:M4x9GyIPj+HoIlHNGpK2hq5o3BFhC+78PkEaldQRphc= +github.com/modelcontextprotocol/go-sdk v1.4.1/go.mod h1:Bo/mS87hPQqHSRkMv4dQq1XCu6zv4INdXnFZabkNU6s= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/segmentio/asm v1.1.3 h1:WM03sfUOENvvKexOLp+pCqgb/WDjsi7EK8gIsICtzhc= +github.com/segmentio/asm v1.1.3/go.mod h1:Ld3L4ZXGNcSLRg4JBsZ3//1+f/TjYl0Mzen/DQy1EJg= +github.com/segmentio/encoding v0.5.4 h1:OW1VRern8Nw6ITAtwSZ7Idrl3MXCFwXHPgqESYfvNt0= +github.com/segmentio/encoding v0.5.4/go.mod h1:HS1ZKa3kSN32ZHVZ7ZLPLXWvOVIiZtyJnO1gPH1sKt0= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw= +golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= +golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= 
+golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc= +golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/demo/code-mode/mcp-stdio-proxy/main.go b/demo/code-mode/mcp-stdio-proxy/main.go new file mode 100644 index 0000000000..03ab9df8aa --- /dev/null +++ b/demo/code-mode/mcp-stdio-proxy/main.go @@ -0,0 +1,162 @@ +package main + +import ( + "context" + "errors" + "flag" + "fmt" + "log" + "net/http" + "os" + "os/signal" + "sync/atomic" + "syscall" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +const ( + // 127.0.0.1 (not localhost) so the Go HTTP client doesn't try ::1 first + // and get refused — the router binds IPv4 only. 
+ defaultUpstreamURL = "http://127.0.0.1:5027/mcp" + proxyName = "yoko-stdio-proxy" + proxyVersion = "0.1.0" +) + +type proxyOptions struct { + upstreamURL string + transport mcp.Transport + httpClient *http.Client +} + +func main() { + log.SetOutput(os.Stderr) + + flags := flag.NewFlagSet(os.Args[0], flag.ExitOnError) + flags.SetOutput(os.Stderr) + upstreamURL := flags.String("upstream", defaultUpstreamURL, "HTTP MCP upstream URL") + flags.Usage = func() { + fmt.Fprintf(flags.Output(), "Usage: mcp-stdio-proxy --upstream \n") + flags.PrintDefaults() + } + if err := flags.Parse(os.Args[1:]); err != nil { + os.Exit(2) + } + if flags.NArg() != 0 { + fmt.Fprintln(os.Stderr, "mcp-stdio-proxy: unexpected positional arguments") + flags.Usage() + os.Exit(2) + } + + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() + + if err := runProxy(ctx, proxyOptions{ + upstreamURL: *upstreamURL, + transport: &mcp.StdioTransport{}, + }); err != nil && !errors.Is(err, context.Canceled) { + log.Fatalf("mcp-stdio-proxy: %v", err) + } +} + +func runProxy(ctx context.Context, opts proxyOptions) error { + if opts.upstreamURL == "" { + opts.upstreamURL = defaultUpstreamURL + } + if opts.transport == nil { + opts.transport = &mcp.StdioTransport{} + } + + var localSession atomic.Pointer[mcp.ServerSession] + upstreamClient := mcp.NewClient( + &mcp.Implementation{Name: proxyName, Version: proxyVersion}, + &mcp.ClientOptions{ + ElicitationHandler: func(ctx context.Context, req *mcp.ElicitRequest) (*mcp.ElicitResult, error) { + ss := localSession.Load() + if ss == nil { + return nil, errors.New("no local session yet") + } + return ss.Elicit(ctx, req.Params) + }, + }, + ) + + upstreamSession, err := upstreamClient.Connect(ctx, &mcp.StreamableClientTransport{ + Endpoint: opts.upstreamURL, + HTTPClient: opts.httpClient, + }, nil) + if err != nil { + return fmt.Errorf("connect upstream %q failed: %w; is the demo running? 
try `make code-mode-demo`", opts.upstreamURL, err) + } + defer func() { + if err := upstreamSession.Close(); err != nil { + log.Printf("mcp-stdio-proxy: upstream close failed: %v", err) + } + }() + + toolsResp, err := upstreamSession.ListTools(ctx, &mcp.ListToolsParams{}) + if err != nil { + return fmt.Errorf("list upstream tools: %w", err) + } + resourcesResp, err := upstreamSession.ListResources(ctx, &mcp.ListResourcesParams{}) + if err != nil { + return fmt.Errorf("list upstream resources: %w", err) + } + + localServer := mcp.NewServer( + &mcp.Implementation{Name: "yoko (via stdio-proxy)", Version: proxyVersion}, + &mcp.ServerOptions{ + InitializedHandler: func(_ context.Context, req *mcp.InitializedRequest) { + localSession.Store(req.Session) + // Log the downstream client's declared capabilities so we know + // whether elicitation forwarding will work end to end. + if p := req.Session.InitializeParams(); p != nil { + hasElicit := p.Capabilities != nil && p.Capabilities.Elicitation != nil + name := "" + ver := "" + if p.ClientInfo != nil { + name = p.ClientInfo.Name + ver = p.ClientInfo.Version + } + log.Printf("mcp-stdio-proxy: downstream initialized name=%q version=%q elicitation=%v", name, ver, hasElicit) + } + }, + }, + ) + + for _, upstreamTool := range toolsResp.Tools { + tool := *upstreamTool + if tool.InputSchema == nil { + tool.InputSchema = map[string]any{"type": "object"} + } + localServer.AddTool(&tool, func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + result, err := upstreamSession.CallTool(ctx, &mcp.CallToolParams{ + Meta: req.Params.Meta, + Name: req.Params.Name, + Arguments: req.Params.Arguments, + }) + if err != nil { + var errResult mcp.CallToolResult + errResult.SetError(fmt.Errorf("upstream tool %q failed: %w", req.Params.Name, err)) + return &errResult, nil + } + return result, nil + }) + } + + for _, upstreamResource := range resourcesResp.Resources { + resource := *upstreamResource + 
localServer.AddResource(&resource, func(ctx context.Context, req *mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) { + result, err := upstreamSession.ReadResource(ctx, req.Params) + if err != nil { + return nil, fmt.Errorf("upstream resource %q failed: %w", req.Params.URI, err) + } + return result, nil + }) + } + + if err := localServer.Run(ctx, opts.transport); err != nil { + return err + } + return nil +} diff --git a/demo/code-mode/mcp-stdio-proxy/main_test.go b/demo/code-mode/mcp-stdio-proxy/main_test.go new file mode 100644 index 0000000000..0086a70b2b --- /dev/null +++ b/demo/code-mode/mcp-stdio-proxy/main_test.go @@ -0,0 +1,298 @@ +package main + +import ( + "context" + "errors" + "net" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestProxyMirrorsUpstreamSurfaceAndForwardsElicitation(t *testing.T) { + tests := []struct { + name string + run func(context.Context, *testing.T, *mcp.ClientSession) + }{ + { + name: "list tools", + run: func(ctx context.Context, t *testing.T, session *mcp.ClientSession) { + resp, err := session.ListTools(ctx, &mcp.ListToolsParams{}) + require.NoError(t, err) + assert.Equal(t, &mcp.ListToolsResult{ + Tools: []*mcp.Tool{ + { + Name: "ask", + Description: "Ask for approval.", + InputSchema: map[string]any{ + "type": "object", + "additionalProperties": false, + }, + }, + { + Name: "echo", + Description: "Echo the input.", + InputSchema: map[string]any{ + "type": "object", + "additionalProperties": true, + }, + }, + }, + }, resp) + }, + }, + { + name: "call echo", + run: func(ctx context.Context, t *testing.T, session *mcp.ClientSession) { + resp, err := session.CallTool(ctx, &mcp.CallToolParams{ + Name: "echo", + Arguments: map[string]any{"x": 1}, + }) + require.NoError(t, err) + assert.Equal(t, &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: `{"x":1}`}}, + 
StructuredContent: map[string]any{"x": float64(1)}, + }, resp) + }, + }, + { + name: "list resources", + run: func(ctx context.Context, t *testing.T, session *mcp.ClientSession) { + resp, err := session.ListResources(ctx, &mcp.ListResourcesParams{}) + require.NoError(t, err) + assert.Equal(t, &mcp.ListResourcesResult{ + Resources: []*mcp.Resource{ + { + URI: "demo://hello", + Name: "hello", + Title: "Hello", + MIMEType: "text/plain", + }, + }, + }, resp) + }, + }, + { + name: "read resource", + run: func(ctx context.Context, t *testing.T, session *mcp.ClientSession) { + resp, err := session.ReadResource(ctx, &mcp.ReadResourceParams{URI: "demo://hello"}) + require.NoError(t, err) + assert.Equal(t, &mcp.ReadResourceResult{ + Contents: []*mcp.ResourceContents{ + { + URI: "demo://hello", + MIMEType: "text/plain", + Text: "hi", + }, + }, + }, resp) + }, + }, + { + name: "call ask forwards elicitation", + run: func(ctx context.Context, t *testing.T, session *mcp.ClientSession) { + resp, err := session.CallTool(ctx, &mcp.CallToolParams{ + Name: "ask", + Arguments: map[string]any{}, + }) + require.NoError(t, err) + assert.Equal(t, &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: `{"approved":true}`}}, + StructuredContent: map[string]any{"approved": true}, + }, resp) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + upstream := newTestUpstream(t) + + serverTransport, clientTransport := mcp.NewInMemoryTransports() + errCh := make(chan error, 1) + go func() { + errCh <- runProxy(ctx, proxyOptions{ + upstreamURL: upstream.URL, + transport: serverTransport, + httpClient: upstream.Client(), + }) + }() + + client := mcp.NewClient(&mcp.Implementation{Name: "test-client", Version: "0.1.0"}, &mcp.ClientOptions{ + ElicitationHandler: func(context.Context, *mcp.ElicitRequest) (*mcp.ElicitResult, error) { + return &mcp.ElicitResult{ + 
Action: "accept", + Content: map[string]any{"approved": true}, + }, nil + }, + }) + session, err := client.Connect(ctx, clientTransport, nil) + require.NoError(t, err) + defer func() { + require.NoError(t, session.Close()) + err := <-errCh + if !errors.Is(err, context.Canceled) { + require.NoError(t, err) + } + }() + + tt.run(ctx, t, session) + }) + } +} + +func newTestUpstream(t *testing.T) *httptest.Server { + t.Helper() + + server := mcp.NewServer(&mcp.Implementation{Name: "test-upstream", Version: "0.1.0"}, nil) + server.AddTool(&mcp.Tool{ + Name: "echo", + Description: "Echo the input.", + InputSchema: map[string]any{ + "type": "object", + "additionalProperties": true, + }, + }, func(_ context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: string(req.Params.Arguments)}}, + StructuredContent: req.Params.Arguments, + }, nil + }) + server.AddTool(&mcp.Tool{ + Name: "ask", + Description: "Ask for approval.", + InputSchema: map[string]any{ + "type": "object", + "additionalProperties": false, + }, + }, func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + result, err := req.Session.Elicit(ctx, &mcp.ElicitParams{ + Message: "Approve mutation?", + RequestedSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "approved": map[string]any{"type": "boolean"}, + }, + }, + }) + if err != nil { + return nil, err + } + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: `{"approved":true}`}}, + StructuredContent: result.Content, + }, nil + }) + server.AddResource(&mcp.Resource{ + URI: "demo://hello", + Name: "hello", + Title: "Hello", + MIMEType: "text/plain", + }, func(context.Context, *mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) { + return &mcp.ReadResourceResult{ + Contents: []*mcp.ResourceContents{ + { + URI: "demo://hello", + MIMEType: "text/plain", + Text: "hi", + }, + }, + }, nil + 
})

+	// Expose the MCP server over streamable HTTP at the mux root so the
+	// proxy under test can reach it like any real upstream.
+	mux := http.NewServeMux()
+	mux.Handle("/", mcp.NewStreamableHTTPHandler(func(*http.Request) *mcp.Server {
+		return server
+	}, nil))
+
+	listener := newPipeListener()
+	t.Cleanup(func() {
+		require.NoError(t, listener.Close())
+	})
+
+	// Assemble the httptest.Server by hand (instead of httptest.NewServer)
+	// so we can inject the in-memory pipe listener — no TCP port is opened.
+	httpServer := &httptest.Server{
+		Listener: listener,
+		Config: &http.Server{
+			Handler: mux,
+			BaseContext: func(net.Listener) context.Context {
+				ctx, cancel := context.WithCancel(context.Background())
+				t.Cleanup(cancel)
+				return ctx
+			},
+		},
+	}
+	httpServer.Start()
+	t.Cleanup(httpServer.Close)
+	httpServer.Client().Transport = &http.Transport{
+		DialContext: listener.DialContext,
+	}
+	return httpServer
+}
+
+// pipeListener is an in-memory net.Listener. DialContext creates a
+// net.Pipe, returns the client end to the dialer, and queues the server
+// end for Accept — so the httptest server and its HTTP client talk over
+// process-local pipes instead of a real socket.
+type pipeListener struct {
+	conns chan net.Conn
+	done  chan struct{}
+}
+
+func newPipeListener() *pipeListener {
+	return &pipeListener{
+		conns: make(chan net.Conn),
+		done:  make(chan struct{}),
+	}
+}
+
+// Accept blocks until DialContext queues a server-side conn or the
+// listener is closed, in which case it reports net.ErrClosed.
+func (l *pipeListener) Accept() (net.Conn, error) {
+	select {
+	case conn := <-l.conns:
+		return conn, nil
+	case <-l.done:
+		return nil, net.ErrClosed
+	}
+}
+
+// Close is idempotent: once done is closed, later calls fall through the
+// select and return nil. NOTE(review): two truly concurrent Close calls
+// could both take the default branch and double-close done; in this test
+// teardown runs sequentially (t.Cleanup + httpServer.Close), so the race
+// is not exercised here.
+func (l *pipeListener) Close() error {
+	select {
+	case <-l.done:
+	default:
+		close(l.done)
+	}
+	return nil
+}
+
+func (l *pipeListener) Addr() net.Addr {
+	return pipeAddr("pipe")
+}
+
+// DialContext matches http.Transport.DialContext. Network and address are
+// ignored — every dial reaches this same listener. Both pipe ends are
+// closed if the dial loses to cancellation or listener shutdown.
+func (l *pipeListener) DialContext(ctx context.Context, _, _ string) (net.Conn, error) {
+	serverConn, clientConn := net.Pipe()
+	select {
+	case l.conns <- serverConn:
+		return clientConn, nil
+	case <-ctx.Done():
+		_ = serverConn.Close()
+		_ = clientConn.Close()
+		return nil, ctx.Err()
+	case <-l.done:
+		_ = serverConn.Close()
+		_ = clientConn.Close()
+		return nil, net.ErrClosed
+	}
+}
+
+type pipeAddr string
+
+func (a pipeAddr) Network() string {
+	return "pipe"
+}
+
+func (a pipeAddr) String() string {
+	return string(a)
+}
diff --git a/demo/code-mode/router-config.yaml b/demo/code-mode/router-config.yaml
new file mode 100644
index 0000000000..01fac390a2
--- /dev/null
+++ b/demo/code-mode/router-config.yaml
@@ -0,0 +1,56 @@
+version: "1"
+
+listen_addr: "localhost:3002"
+graphql_path: "/graphql"
+playground_enabled: false
+json_log: false
+log_level: info
+dev_mode: true
+router_registration: false
+
+execution_config:
+  file:
+    path: "demo/code-mode/config.json"
+    watch: false
+
+graphql_metrics:
+  enabled: false
+
+telemetry:
+  tracing:
+    enabled: false
+  metrics:
+    otlp:
+      enabled: false
+    prometheus:
+      enabled: false
+
+mcp:
+  enabled: false
+  graph_name: code-mode-demo
+  router_url: http://localhost:3002/graphql
+  session:
+    stateless: false
+  code_mode:
+    enabled: true
+    server:
+      # Bind IPv4 explicitly. On macOS, "localhost:5027" binds only IPv4
+      # but clients that resolve "localhost" to ::1 first (Go's resolver,
+      # the MCP stdio proxy) get refused — point them at 127.0.0.1 directly
+      # in start.sh and the proxy defaults.
+      listen_addr: 127.0.0.1:5027
+    require_mutation_approval: true
+    # Sandbox wall-clock cap. Default is 5s (plan §13), which is fine for
+    # compute-only agent code but too short whenever the host blocks the JS
+    # thread on an interactive elicitation. Bump to 180s so a human can review
+    # a mutation prompt without the qjs runtime context expiring under us.
+    sandbox:
+      timeout: 180s
+    query_generation:
+      enabled: true
+      endpoint: http://localhost:5028
+      timeout: 180s
+      execute_timeout: 180s
+    named_ops:
+      enabled: true
+    # storage.provider_id intentionally unset -> in-memory backend (the default)
diff --git a/demo/code-mode/run_subgraphs_subset.sh b/demo/code-mode/run_subgraphs_subset.sh
new file mode 100755
index 0000000000..23e2c20ec3
--- /dev/null
+++ b/demo/code-mode/run_subgraphs_subset.sh
@@ -0,0 +1,13 @@
+#!/bin/sh
+
+set -eu
+
+cd "$(dirname "$0")/.."
+GOCACHE="${GOCACHE:-/tmp/cosmo-code-mode-go-build-cache}" +mkdir -p "$GOCACHE" + +npx concurrently --kill-others \ + "GOCACHE=$GOCACHE PORT=4001 go run ./cmd/employees" \ + "GOCACHE=$GOCACHE PORT=4002 go run ./cmd/family" \ + "GOCACHE=$GOCACHE PORT=4007 go run ./cmd/availability" \ + "GOCACHE=$GOCACHE PORT=4008 go run ./cmd/mood" diff --git a/demo/code-mode/start.sh b/demo/code-mode/start.sh new file mode 100755 index 0000000000..c079e1d1db --- /dev/null +++ b/demo/code-mode/start.sh @@ -0,0 +1,149 @@ +#!/usr/bin/env bash + +set -Eeuo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)" +DEMO_DIR="$ROOT_DIR/demo" +CODE_MODE_DIR="$DEMO_DIR/code-mode" +PID_FILE="/tmp/cosmo-code-mode-demo.pids" +LOG_DIR="/tmp/cosmo-code-mode-demo-logs" +GOCACHE_DIR="${GOCACHE:-/tmp/cosmo-code-mode-go-build-cache}" + +ROUTER_BIN="$ROOT_DIR/router/router" +ROUTER_CONFIG="$CODE_MODE_DIR/router-config.yaml" +YOKO_BIN="$CODE_MODE_DIR/yoko-mock/yoko-mock" + +append_pid() { + local name="$1" + local pid="$2" + printf '%s %s\n' "$name" "$pid" >> "$PID_FILE" +} + +kill_pid_file() { + if [ ! -f "$PID_FILE" ]; then + echo "No Code Mode demo PID file found at $PID_FILE" + return 0 + fi + + while read -r name pid; do + [ -n "${pid:-}" ] || continue + if kill -0 "$pid" 2>/dev/null; then + echo "Stopping $name pid=$pid" + kill "$pid" 2>/dev/null || true + fi + done < "$PID_FILE" + + sleep 1 + + while read -r name pid; do + [ -n "${pid:-}" ] || continue + if kill -0 "$pid" 2>/dev/null; then + echo "Force stopping $name pid=$pid" + kill -KILL "$pid" 2>/dev/null || true + fi + done < "$PID_FILE" + + rm -f "$PID_FILE" +} + +cleanup() { + local status=$? 
+ trap - EXIT INT TERM + kill_pid_file + exit "$status" +} + +wait_url() { + local name="$1" + local url="$2" + local timeout="${3:-90}" + local start + start="$(date +%s)" + + while true; do + if curl -fsS "$url" >/dev/null 2>&1; then + echo "$name is ready at $url" + return 0 + fi + + if [ "$(( $(date +%s) - start ))" -ge "$timeout" ]; then + echo "Timed out waiting for $name at $url" >&2 + echo "Logs are in $LOG_DIR" >&2 + return 1 + fi + + sleep 1 + done +} + +start_background() { + local name="$1" + local cwd="$2" + shift 2 + + echo "Starting $name" + ( + cd "$cwd" + "$@" + ) > "$LOG_DIR/$name.log" 2>&1 & + append_pid "$name" "$!" +} + +start_background_root() { + local name="$1" + shift + + echo "Starting $name" + ( + cd "$ROOT_DIR" + "$@" + ) > "$LOG_DIR/$name.log" 2>&1 & + append_pid "$name" "$!" +} + +if [ "${1:-}" = "--down" ]; then + kill_pid_file + exit 0 +fi + +if [ ! -x "$ROUTER_BIN" ]; then + echo "Router binary not found or not executable: $ROUTER_BIN" >&2 + echo "Run: cd router && make build" >&2 + exit 1 +fi + +if [ ! -x "$YOKO_BIN" ]; then + echo "Yoko mock binary not found or not executable: $YOKO_BIN" >&2 + echo "Run: cd demo/code-mode/yoko-mock && go build -o yoko-mock ." >&2 + exit 1 +fi + +if [ ! 
-f "$CODE_MODE_DIR/config.json" ]; then + echo "Composed router config not found: $CODE_MODE_DIR/config.json" >&2 + echo "Run: make -C demo/code-mode compose" >&2 + exit 1 +fi + +mkdir -p "$LOG_DIR" +mkdir -p "$GOCACHE_DIR" +rm -f "$PID_FILE" +trap cleanup EXIT INT TERM + +start_background employees "$DEMO_DIR" env GOCACHE="$GOCACHE_DIR" PORT=4001 go run ./cmd/employees +start_background family "$DEMO_DIR" env GOCACHE="$GOCACHE_DIR" PORT=4002 go run ./cmd/family +start_background availability "$DEMO_DIR" env GOCACHE="$GOCACHE_DIR" PORT=4007 go run ./cmd/availability +start_background mood "$DEMO_DIR" env GOCACHE="$GOCACHE_DIR" PORT=4008 go run ./cmd/mood +start_background_root yoko "$YOKO_BIN" -listen-addr localhost:5028 + +wait_url employees http://localhost:4001/ +wait_url family http://localhost:4002/ +wait_url availability http://localhost:4007/ +wait_url mood http://localhost:4008/ +wait_url yoko http://localhost:5028/health + +echo "Starting router in foreground" +"$ROUTER_BIN" -config "$ROUTER_CONFIG" & +router_pid="$!" +append_pid router "$router_pid" + +wait "$router_pid" diff --git a/demo/code-mode/yoko-mock/.gitignore b/demo/code-mode/yoko-mock/.gitignore new file mode 100644 index 0000000000..f3e6959ad6 --- /dev/null +++ b/demo/code-mode/yoko-mock/.gitignore @@ -0,0 +1,3 @@ +yoko-mock +bench +cmd/bench/bench diff --git a/demo/code-mode/yoko-mock/README.md b/demo/code-mode/yoko-mock/README.md new file mode 100644 index 0000000000..c688b43f6d --- /dev/null +++ b/demo/code-mode/yoko-mock/README.md @@ -0,0 +1,46 @@ +# Yoko Mock + +This is a demo implementation of the Code Mode `YokoService` Connect RPC. It indexes a supergraph SDL in memory, then shells out to the host `codex` CLI to generate GraphQL operations for natural-language prompts. + +## Run + +From the repository root: + +```sh +go run ./demo/code-mode/yoko-mock --listen-addr :5028 +``` + +Flags: + +- `--listen-addr` defaults to `localhost:5028`. 
+- `--codex-bin` defaults to `codex` and is resolved through `PATH` unless an absolute path is supplied. +- `--codex-timeout` defaults to `60s`. + +The service calls: + +```sh +codex exec --full-auto --skip-git-repo-check - +``` + +with the generated prompt on stdin. The host must have a real `codex` CLI installed and authenticated. + +## Behavior + +- `POST /wundergraph.cosmo.code_mode.yoko.v1.YokoService/Index` stores the SDL in memory and returns `schema_id`, the first 16 hex characters of `sha256(schema_sdl)`. +- `POST /wundergraph.cosmo.code_mode.yoko.v1.YokoService/Search` looks up `schema_id`, invokes `codex`, parses its stdout as a JSON array, and returns the generated operations without local deduping or ranking. +- `/health` returns `200 OK`. + +If `Search` receives an unknown `schema_id`, it returns Connect `NOT_FOUND`; the router client is expected to re-index and retry once. If `codex` returns invalid JSON, the service logs a warning, writes the raw stdout to `/tmp/yoko-mock-last-bad-output.log`, and returns Connect `INTERNAL`. + +Expected codex stdout: + +```json +[ + { + "name": "getViewer", + "body": "query getViewer { viewer { id } }", + "kind": "query", + "description": "Fetches the current viewer." 
+ } +] +``` diff --git a/demo/code-mode/yoko-mock/go.mod b/demo/code-mode/yoko-mock/go.mod new file mode 100644 index 0000000000..807baae723 --- /dev/null +++ b/demo/code-mode/yoko-mock/go.mod @@ -0,0 +1,22 @@ +module github.com/wundergraph/cosmo/demo/code-mode/yoko-mock + +go 1.25.0 + +require ( + connectrpc.com/connect v1.19.1 + github.com/dgraph-io/ristretto/v2 v2.4.0 + github.com/stretchr/testify v1.11.1 + github.com/wundergraph/cosmo/router v0.0.0 + google.golang.org/protobuf v1.36.10 +) + +require ( + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/sys v0.40.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) + +replace github.com/wundergraph/cosmo/router => ../../../router diff --git a/demo/code-mode/yoko-mock/go.sum b/demo/code-mode/yoko-mock/go.sum new file mode 100644 index 0000000000..e60cb737e8 --- /dev/null +++ b/demo/code-mode/yoko-mock/go.sum @@ -0,0 +1,26 @@ +connectrpc.com/connect v1.19.1 h1:R5M57z05+90EfEvCY1b7hBxDVOUl45PrtXtAV2fOC14= +connectrpc.com/connect v1.19.1/go.mod h1:tN20fjdGlewnSFeZxLKb0xwIZ6ozc3OQs2hTXy4du9w= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgraph-io/ristretto/v2 v2.4.0 h1:I/w09yLjhdcVD2QV192UJcq8dPBaAJb9pOuMyNy0XlU= +github.com/dgraph-io/ristretto/v2 v2.4.0/go.mod h1:0KsrXtXvnv0EqnzyowllbVJB8yBonswa2lTCK2gGo9E= +github.com/dgryski/go-farm v0.0.0-20240924180020-3414d57e47da h1:aIftn67I1fkbMa512G+w+Pxci9hJPB8oMnkcP3iZF38= +github.com/dgryski/go-farm v0.0.0-20240924180020-3414d57e47da/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= 
+github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= +golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE= +google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/demo/code-mode/yoko-mock/main.go b/demo/code-mode/yoko-mock/main.go new file mode 100644 index 0000000000..3a412fe48f --- /dev/null +++ b/demo/code-mode/yoko-mock/main.go @@ -0,0 +1,583 @@ +package main + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/json" + "errors" + "flag" + "fmt" + "log" + "net/http" + "os" + "os/exec" + "os/signal" + "path/filepath" + "strings" + "sync" + "sync/atomic" + "syscall" + "time" + + "connectrpc.com/connect" + "github.com/dgraph-io/ristretto/v2" + yokov1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/code_mode/yoko/v1" + 
"github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/code_mode/yoko/v1/yokov1connect" +) + +const badOutputPath = "/tmp/yoko-mock-last-bad-output.log" + +type yokoService struct { + codexBin string + codexTimeout time.Duration + codexReasoningEffort string + rotateAfter int // re-warm the codex session after this many Search calls; 0 disables + + // promptCache memoizes (schemaID, prompt) -> GeneratedOperation. A cache + // hit lets us skip codex entirely for that prompt. nil if the cache is + // disabled (size <= 0). + promptCache *ristretto.Cache[string, *yokov1.GeneratedOperation] + + mu sync.RWMutex + schemas map[string]*schemaEntry +} + +// schemaEntry records the on-disk schema dir (so codex can read schema.graphql +// once at Index time) plus the codex session id created during that pre-warm. +// Search uses `codex exec resume ` to reuse the already-loaded +// schema context instead of re-reading it on every call. +// +// To bound session-file growth, every yokoService.rotateAfter Search calls a +// background goroutine pre-warms a fresh session and atomically swaps the +// sessionID. searchCount tracks calls; rotationActive ensures only one +// rotation runs at a time. 
+type schemaEntry struct { + dir string + + mu sync.RWMutex + sessionID string + + searchCount atomic.Int64 + rotationActive atomic.Bool +} + +type codexOperation struct { + Name string `json:"name"` + Body string `json:"body"` + Kind string `json:"kind"` + Description string `json:"description"` +} + +type codexOutput struct { + Operations []codexOperation `json:"operations"` +} + +func main() { + listenAddr := flag.String("listen-addr", "localhost:5028", "address for the Yoko mock HTTP server") + codexBin := flag.String("codex-bin", "codex", "codex CLI binary path or name") + codexTimeout := flag.Duration("codex-timeout", 60*time.Second, "codex CLI timeout") + codexReasoningEffort := flag.String("codex-reasoning-effort", "low", "codex reasoning effort: minimal | low | medium | high") + codexRotateAfter := flag.Int("codex-rotate-after", 20, "re-warm the codex session after N Search calls (0 = disable rotation)") + promptCacheSize := flag.Int("prompt-cache-size", 1000, "max items in the (schema_id, prompt) -> operation cache (0 = disable)") + flag.Parse() + + svc, err := newYokoService(*codexBin, *codexTimeout, *codexReasoningEffort, *codexRotateAfter, *promptCacheSize) + if err != nil { + log.Fatalf("create yoko service: %v", err) + } + defer svc.Close() + server := &http.Server{ + Addr: *listenAddr, + Handler: newHTTPMux(svc), + } + + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() + + errCh := make(chan error, 1) + go func() { + log.Printf("yoko mock listening addr=%s codex_bin=%s codex_timeout=%s reasoning_effort=%s", *listenAddr, *codexBin, codexTimeout.String(), *codexReasoningEffort) + errCh <- server.ListenAndServe() + }() + + select { + case <-ctx.Done(): + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := server.Shutdown(shutdownCtx); err != nil { + log.Fatalf("server shutdown failed: %v", err) + } + case err := <-errCh: + if err != nil && 
!errors.Is(err, http.ErrServerClosed) { + log.Fatalf("server failed: %v", err) + } + } +} + +func newYokoService(codexBin string, codexTimeout time.Duration, reasoningEffort string, rotateAfter, promptCacheSize int) (*yokoService, error) { + svc := &yokoService{ + codexBin: codexBin, + codexTimeout: codexTimeout, + codexReasoningEffort: reasoningEffort, + rotateAfter: rotateAfter, + schemas: make(map[string]*schemaEntry), + } + if promptCacheSize > 0 { + // Each cache entry has cost 1, so MaxCost is the item ceiling. + // NumCounters is conventionally 10× expected items. + cache, err := ristretto.NewCache(&ristretto.Config[string, *yokov1.GeneratedOperation]{ + NumCounters: int64(promptCacheSize) * 10, + MaxCost: int64(promptCacheSize), + BufferItems: 64, + }) + if err != nil { + return nil, fmt.Errorf("create prompt cache: %w", err) + } + svc.promptCache = cache + } + return svc, nil +} + +func newHTTPMux(svc *yokoService) *http.ServeMux { + mux := http.NewServeMux() + mux.HandleFunc("/health", func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("OK\n")) + }) + path, handler := yokov1connect.NewYokoServiceHandler(svc) + mux.Handle(path, handler) + return mux +} + +func (s *yokoService) Index(ctx context.Context, req *connect.Request[yokov1.IndexRequest]) (*connect.Response[yokov1.IndexResponse], error) { + schemaSDL := req.Msg.GetSchemaSdl() + id := schemaID(schemaSDL) + + s.mu.Lock() + if existing, ok := s.schemas[id]; ok { + s.mu.Unlock() + existing.mu.RLock() + existingSession := existing.sessionID + existing.mu.RUnlock() + log.Printf("Index schema_id=%s reused dir=%s session_id=%s", id, existing.dir, existingSession) + return connect.NewResponse(&yokov1.IndexResponse{SchemaId: id}), nil + } + s.mu.Unlock() + + dir, err := os.MkdirTemp("", "yoko-schema-"+id+"-") + if err != nil { + return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("create schema temp dir: %w", err)) + } + if err := 
os.WriteFile(filepath.Join(dir, "schema.graphql"), []byte(schemaSDL), 0o600); err != nil { + _ = os.RemoveAll(dir) + return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("write schema.graphql: %w", err)) + } + + sessionID, err := s.runCodexIndex(ctx, dir) + if err != nil { + _ = os.RemoveAll(dir) + return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("codex pre-warm: %w", err)) + } + + entry := &schemaEntry{dir: dir, sessionID: sessionID} + s.mu.Lock() + s.schemas[id] = entry + s.mu.Unlock() + + log.Printf("Index schema_id=%s schema_sdl_size=%d schema_dir=%s session_id=%s rotate_after=%d", id, len(schemaSDL), dir, sessionID, s.rotateAfter) + return connect.NewResponse(&yokov1.IndexResponse{SchemaId: id}), nil +} + +// Close removes every per-schema temp dir created by Index. Safe to call +// multiple times; subsequent calls are no-ops. Codex session rollout files +// live under ~/.codex/sessions/ and are intentionally left in place — they +// belong to the user's codex install. 
+func (s *yokoService) Close() { + s.mu.Lock() + defer s.mu.Unlock() + for id, entry := range s.schemas { + if err := os.RemoveAll(entry.dir); err != nil { + log.Printf("Close schema_id=%s dir=%s err=%v", id, entry.dir, err) + continue + } + log.Printf("Close schema_id=%s dir=%s removed", id, entry.dir) + } + s.schemas = nil + if s.promptCache != nil { + s.promptCache.Close() + s.promptCache = nil + } +} + +func (s *yokoService) Search(ctx context.Context, req *connect.Request[yokov1.SearchRequest]) (*connect.Response[yokov1.SearchResponse], error) { + schemaID := req.Msg.GetSchemaId() + prompts := req.Msg.GetPrompts() + + s.mu.RLock() + entry, ok := s.schemas[schemaID] + s.mu.RUnlock() + if !ok { + log.Printf("Search schema_id=%s prompt_count=%d not_found=true", schemaID, len(prompts)) + return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("schema_id %q not found; call Index before Search", schemaID)) + } + + // Bump per-session call counter; if we crossed the threshold and no + // rotation is in flight, kick one off in the background. The CAS makes + // the trigger one-shot until rotation completes and clears the flag. + count := entry.searchCount.Add(1) + if s.rotateAfter > 0 && count >= int64(s.rotateAfter) && entry.rotationActive.CompareAndSwap(false, true) { + go s.rotateSession(schemaID, entry, count) + } + + // Cache lookup: collect cached ops in their original positions, batch + // only the misses to codex. 
+ results := make([]*yokov1.GeneratedOperation, len(prompts)) + missing := make([]string, 0, len(prompts)) + missingIdx := make([]int, 0, len(prompts)) + hits := 0 + for i, p := range prompts { + if op, ok := s.cacheGet(schemaID, p); ok { + results[i] = op + hits++ + } else { + missing = append(missing, p) + missingIdx = append(missingIdx, i) + } + } + + if len(missing) == 0 { + log.Printf("Search schema_id=%s prompt_count=%d cache_hits=%d cache_misses=0 codex_skipped=true", schemaID, len(prompts), hits) + return connect.NewResponse(&yokov1.SearchResponse{Operations: filterNonNil(results)}), nil + } + + entry.mu.RLock() + sessionID := entry.sessionID + entry.mu.RUnlock() + + prompt := buildCodexPrompt(missing) + stdout, err := s.runCodexResume(ctx, sessionID, prompt) + if err != nil { + return nil, connect.NewError(connect.CodeInternal, err) + } + + generated, err := parseCodexOperations(stdout) + if err != nil { + if writeErr := os.WriteFile(badOutputPath, stdout, 0o600); writeErr != nil { + log.Printf("warning: failed to write bad codex output path=%s err=%v", badOutputPath, writeErr) + } + log.Printf("warning: codex output was not valid JSON schema_id=%s prompt_count=%d stdout_size=%d err=%v", schemaID, len(missing), len(stdout), err) + return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("codex output was not valid JSON; raw output saved to %s", badOutputPath)) + } + + // Pair generated ops back into the original prompt slots and cache the + // successful ones. We trust order: codex was instructed to return one + // operation per missing prompt in the same order. If codex returned + // fewer ops than asked, the trailing prompts have no slot filled (and + // don't get cached). + for k, idx := range missingIdx { + if k >= len(generated) { + break + } + op := generated[k] + if op == nil || op.GetBody() == "" { + // Failed prompt — don't cache, leave slot nil (filtered out below). 
+ continue + } + results[idx] = op + s.cachePut(schemaID, missing[k], op) + } + + log.Printf("Search schema_id=%s prompt_count=%d cache_hits=%d cache_misses=%d codex_stdout_size=%d parsed_op_count=%d", schemaID, len(prompts), hits, len(missing), len(stdout), len(generated)) + return connect.NewResponse(&yokov1.SearchResponse{Operations: filterNonNil(results)}), nil +} + +func filterNonNil(ops []*yokov1.GeneratedOperation) []*yokov1.GeneratedOperation { + out := ops[:0] + for _, op := range ops { + if op != nil { + out = append(out, op) + } + } + return out +} + +// cacheKey returns the (schema_id, prompt) lookup key. We include schema_id +// so the same prompt against a different supergraph doesn't return a stale +// operation. +func cacheKey(schemaID, prompt string) string { + return schemaID + "\x00" + prompt +} + +func (s *yokoService) cacheGet(schemaID, prompt string) (*yokov1.GeneratedOperation, bool) { + if s.promptCache == nil { + return nil, false + } + return s.promptCache.Get(cacheKey(schemaID, prompt)) +} + +func (s *yokoService) cachePut(schemaID, prompt string, op *yokov1.GeneratedOperation) { + if s.promptCache == nil { + return + } + s.promptCache.Set(cacheKey(schemaID, prompt), op, 1) +} + +// rotateSession is launched in a goroutine when Search counts cross +// rotateAfter. It pre-warms a fresh codex session against the same on-disk +// schema, then atomically swaps in the new sessionID and resets the search +// counter. While rotation is running, concurrent Search calls keep using the +// old sessionID — they just don't trigger a second rotation. 
+func (s *yokoService) rotateSession(schemaID string, entry *schemaEntry, triggerCount int64) { + start := time.Now() + log.Printf("rotation kickoff schema_id=%s trigger_count=%d", schemaID, triggerCount) + + ctx, cancel := context.WithTimeout(context.Background(), s.codexTimeout) + defer cancel() + + newSessionID, err := s.runCodexIndex(ctx, entry.dir) + if err != nil { + log.Printf("rotation failed schema_id=%s elapsed=%s err=%v", schemaID, time.Since(start).Round(time.Millisecond), err) + entry.rotationActive.Store(false) + return + } + + entry.mu.Lock() + oldSessionID := entry.sessionID + entry.sessionID = newSessionID + entry.mu.Unlock() + + // Reset count BEFORE clearing rotationActive so a Search arriving in this + // gap can't trigger a second rotation on a freshly-rotated session. + entry.searchCount.Store(0) + entry.rotationActive.Store(false) + + log.Printf("rotation complete schema_id=%s old_session=%s new_session=%s elapsed=%s", schemaID, oldSessionID, newSessionID, time.Since(start).Round(time.Millisecond)) +} + +func schemaID(schemaSDL string) string { + sum := sha256.Sum256([]byte(schemaSDL)) + return fmt.Sprintf("%x", sum)[:16] +} + +const indexCodexPrompt = `Read the COMPLETE content of the file ./schema.graphql in your current working directory using your file-reading tool. Read the ENTIRE file (it is approximately 17KB and 824 lines) — do not truncate, do not skim, do not read only a portion. The file is a federated GraphQL supergraph SDL. + +Once the full schema is loaded into your context, output exactly this JSON object and nothing else: + +{"ready":true} + +Do not include preamble, prose, markdown fences, or commentary.` + +func buildCodexPrompt(prompts []string) string { + var b strings.Builder + b.WriteString("You already loaded the federated GraphQL supergraph SDL from\n") + b.WriteString("./schema.graphql earlier in this session. 
Use it as the source of\n") + b.WriteString("truth — do not re-read the file.\n\n") + b.WriteString("For each user prompt below, generate ONE corresponding GraphQL\n") + b.WriteString("operation (query or mutation) that fulfills the prompt against\n") + b.WriteString("the schema. Return one operation per prompt, in the same order.\n\n") + b.WriteString("PARAMETERIZATION REQUIREMENT (load-bearing):\n") + b.WriteString("Whenever an argument's value depends on the caller's intent (an id,\n") + b.WriteString("a filter, a name, a tag, a limit, etc.), you MUST declare a GraphQL\n") + b.WriteString("variable for it and reference it via $varName. NEVER inline a literal\n") + b.WriteString("for caller-controlled arguments.\n") + b.WriteString("Example query: query employeeByID($id: Int!) { employee(id: $id) { id details { forename surname } } }\n") + b.WriteString("Example mutation: mutation updateEmployeeTag($id: Int!, $tag: String!) { updateEmployeeTag(id: $id, tag: $tag) { id tag } }\n") + b.WriteString("Only inline a literal when the argument is genuinely fixed by the prompt\n") + b.WriteString("(for example, 'list ALL employees' might pass no args at all). Variable\n") + b.WriteString("types must match the schema, including non-null bangs.\n\n") + b.WriteString("OUTPUT FORMAT (strict, machine-parsed):\n") + b.WriteString("- Output a single JSON object with one key: \"operations\" (array).\n") + b.WriteString("- Each operation has keys: name (camelCase), body (operation\n") + b.WriteString(" source text starting with 'query (...)' or\n") + b.WriteString(" 'mutation (...)' when variables are declared, or\n") + b.WriteString(" 'query { ... }' / 'mutation { ... 
}' when truly\n") + b.WriteString(" variable-free), kind ('query' or 'mutation'), description\n") + b.WriteString(" (one short sentence).\n") + b.WriteString("- operations.length MUST equal the number of user prompts below,\n") + b.WriteString(" in the same order.\n") + b.WriteString("- No prose, no preamble, no markdown fences.\n\n") + b.WriteString("USER PROMPTS:\n") + for _, prompt := range prompts { + b.WriteString("- ") + b.WriteString(prompt) + b.WriteByte('\n') + } + return b.String() +} + +// runCodexIndex performs the one-time pre-warm: codex reads schema.graphql in +// schemaDir and a session is started. The session id (UUID) is parsed from +// codex's first JSONL event and returned so subsequent Search calls can resume +// the same session. +func (s *yokoService) runCodexIndex(ctx context.Context, schemaDir string) (string, error) { + cmdCtx, cancel := context.WithTimeout(ctx, s.codexTimeout) + defer cancel() + + args := []string{ + "exec", + "--json", + "-s", "read-only", + "--skip-git-repo-check", + "--ignore-user-config", + "--ignore-rules", + "-c", "model_reasoning_effort=" + s.codexReasoningEffort, + "-c", "approval_policy=never", + "-", + } + + start := time.Now() + cmd := exec.CommandContext(cmdCtx, s.codexBin, args...) 
+ cmd.Dir = schemaDir + cmd.Stdin = strings.NewReader(indexCodexPrompt) + + var stderr bytes.Buffer + cmd.Stderr = &stderr + + stdout, err := cmd.Output() + elapsed := time.Since(start) + exitCode := 0 + if err != nil { + exitCode = -1 + var exitErr *exec.ExitError + if errors.As(err, &exitErr) { + exitCode = exitErr.ExitCode() + } + } + log.Printf("codex index duration=%s exit_code=%d stdout_prefix=%q stderr_prefix=%q", elapsed.Round(time.Millisecond), exitCode, prefix(stdout, 160), prefix(stderr.Bytes(), 160)) + + if cmdCtx.Err() != nil { + return "", fmt.Errorf("codex index timed out after %s", s.codexTimeout) + } + if err != nil { + return "", fmt.Errorf("codex index failed exit_code=%d stderr=%q: %w", exitCode, prefix(stderr.Bytes(), 300), err) + } + + return parseThreadID(stdout) +} + +// runCodexResume resumes the previously-warmed session and runs the user +// prompts. The agent's last message (a JSON object of operations) is captured +// via `--output-last-message` and returned for parsing. +func (s *yokoService) runCodexResume(ctx context.Context, sessionID, prompt string) ([]byte, error) { + cmdCtx, cancel := context.WithTimeout(ctx, s.codexTimeout) + defer cancel() + + outFile, err := os.CreateTemp("", "yoko-search-out-*.txt") + if err != nil { + return nil, fmt.Errorf("create output temp file: %w", err) + } + outPath := outFile.Name() + _ = outFile.Close() + defer os.Remove(outPath) + + args := []string{ + "exec", "resume", sessionID, + "-o", outPath, + "--skip-git-repo-check", + "--ignore-user-config", + "--ignore-rules", + "-c", "model_reasoning_effort=" + s.codexReasoningEffort, + "-c", "approval_policy=never", + "-", + } + + start := time.Now() + cmd := exec.CommandContext(cmdCtx, s.codexBin, args...) 
+ cmd.Stdin = strings.NewReader(prompt) + + var stderr bytes.Buffer + cmd.Stderr = &stderr + + err = cmd.Run() + elapsed := time.Since(start) + exitCode := 0 + if err != nil { + exitCode = -1 + var exitErr *exec.ExitError + if errors.As(err, &exitErr) { + exitCode = exitErr.ExitCode() + } + } + + if cmdCtx.Err() != nil { + return nil, fmt.Errorf("codex resume timed out after %s", s.codexTimeout) + } + if err != nil { + return nil, fmt.Errorf("codex resume failed exit_code=%d stderr=%q: %w", exitCode, prefix(stderr.Bytes(), 300), err) + } + + output, err := os.ReadFile(outPath) + if err != nil { + return nil, fmt.Errorf("read codex last message: %w", err) + } + log.Printf("codex resume duration=%s session_id=%s out_size=%d out_prefix=%q", elapsed.Round(time.Millisecond), sessionID, len(output), prefix(output, 160)) + return output, nil +} + +// parseThreadID reads the first JSONL event from codex stdout and extracts +// the thread/session UUID from a `thread.started` event. +func parseThreadID(stdout []byte) (string, error) { + line, _, _ := bytes.Cut(stdout, []byte("\n")) + var ev struct { + Type string `json:"type"` + ThreadID string `json:"thread_id"` + } + if err := json.Unmarshal(line, &ev); err != nil { + return "", fmt.Errorf("parse thread.started event: %w (line=%q)", err, prefix(line, 200)) + } + if ev.Type != "thread.started" || ev.ThreadID == "" { + return "", fmt.Errorf("expected thread.started event with thread_id, got: %q", prefix(line, 200)) + } + return ev.ThreadID, nil +} + +func parseCodexOperations(stdout []byte) ([]*yokov1.GeneratedOperation, error) { + payload := extractJSONObject(stdout) + var parsed codexOutput + if err := json.Unmarshal(payload, &parsed); err != nil { + return nil, err + } + + ops := make([]*yokov1.GeneratedOperation, 0, len(parsed.Operations)) + for _, op := range parsed.Operations { + ops = append(ops, &yokov1.GeneratedOperation{ + Name: op.Name, + Body: op.Body, + Kind: operationKind(op.Kind), + Description: 
op.Description, + }) + } + return ops, nil +} + +func operationKind(kind string) yokov1.OperationKind { + switch strings.ToLower(kind) { + case "query": + return yokov1.OperationKind_OPERATION_KIND_QUERY + case "mutation": + return yokov1.OperationKind_OPERATION_KIND_MUTATION + default: + return yokov1.OperationKind_OPERATION_KIND_UNSPECIFIED + } +} + +// extractJSONObject returns the substring from the first '{' to the last '}' +// in stdout. Resume calls don't support --output-schema, so this guards +// against occasional preamble or trailing prose so json.Unmarshal still +// succeeds. +func extractJSONObject(stdout []byte) []byte { + start := bytes.IndexByte(stdout, '{') + end := bytes.LastIndexByte(stdout, '}') + if start < 0 || end < 0 || end < start { + return stdout + } + return stdout[start : end+1] +} + +func prefix(value []byte, max int) string { + if len(value) <= max { + return string(value) + } + return string(value[:max]) +} diff --git a/demo/code-mode/yoko-mock/main_test.go b/demo/code-mode/yoko-mock/main_test.go new file mode 100644 index 0000000000..61b0ea3d4f --- /dev/null +++ b/demo/code-mode/yoko-mock/main_test.go @@ -0,0 +1,169 @@ +package main + +import ( + "context" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "runtime" + "testing" + "time" + + "connectrpc.com/connect" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + yokov1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/code_mode/yoko/v1" + "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/code_mode/yoko/v1/yokov1connect" + "google.golang.org/protobuf/proto" +) + +func TestIndexThenSearchReturnsGeneratedOperations(t *testing.T) { + writeFakeCodex(t, + `{"type":"thread.started","thread_id":"fake-thread"}`, + `{"operations":[{"name":"getViewer","body":"query getViewer { viewer { id } }","kind":"query","description":"Fetches the current viewer."}]}`, + ) + client := newTestClient(t) + + indexResp, err := 
client.Index(context.Background(), connect.NewRequest(&yokov1.IndexRequest{ + SchemaSdl: "type Query { viewer: User } type User { id: ID! }", + })) + require.NoError(t, err) + + searchResp, err := client.Search(context.Background(), connect.NewRequest(&yokov1.SearchRequest{ + SchemaId: indexResp.Msg.GetSchemaId(), + Prompts: []string{"get the viewer"}, + SessionId: "session-1", + })) + require.NoError(t, err) + + expected := &yokov1.SearchResponse{ + Operations: []*yokov1.GeneratedOperation{ + { + Name: "getViewer", + Body: "query getViewer { viewer { id } }", + Kind: yokov1.OperationKind_OPERATION_KIND_QUERY, + Description: "Fetches the current viewer.", + }, + }, + } + assert.Equal(t, normalizeSearchResponse(t, expected), normalizeSearchResponse(t, searchResp.Msg)) +} + +func TestSearchUnknownSchemaIDReturnsNotFound(t *testing.T) { + writeFakeCodex(t, + `{"type":"thread.started","thread_id":"fake-thread"}`, + `{"operations":[]}`, + ) + client := newTestClient(t) + + _, err := client.Search(context.Background(), connect.NewRequest(&yokov1.SearchRequest{ + SchemaId: "unknown", + Prompts: []string{"get the viewer"}, + })) + + require.Error(t, err) + assert.Equal(t, connect.CodeNotFound, connect.CodeOf(err)) +} + +func TestSearchBadJSONReturnsInternal(t *testing.T) { + writeFakeCodex(t, + `{"type":"thread.started","thread_id":"fake-thread"}`, + `not json`, + ) + client := newTestClient(t) + + indexResp, err := client.Index(context.Background(), connect.NewRequest(&yokov1.IndexRequest{ + SchemaSdl: "type Query { viewer: ID! 
}", + })) + require.NoError(t, err) + + _, err = client.Search(context.Background(), connect.NewRequest(&yokov1.SearchRequest{ + SchemaId: indexResp.Msg.GetSchemaId(), + Prompts: []string{"get the viewer"}, + })) + + require.Error(t, err) + assert.Equal(t, connect.CodeInternal, connect.CodeOf(err)) +} + +func newTestClient(t *testing.T) yokov1connect.YokoServiceClient { + t.Helper() + + svc, err := newYokoService("codex", time.Second, "low", 0, 16) // disable rotation; small cache + require.NoError(t, err) + t.Cleanup(svc.Close) + mux := newHTTPMux(svc) + httpClient := &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + return rec.Result(), nil + })} + + return yokov1connect.NewYokoServiceClient(httpClient, "http://yoko.test") +} + +// writeFakeCodex installs a stub `codex` binary on PATH that mocks both the +// initial `codex exec` (Index pre-warm) and `codex exec resume` (Search) calls. +// The stub detects "resume" in its argv to switch modes. +// +// - indexStdout is printed to stdout for the Index call (e.g. a JSONL line +// like {"type":"thread.started","thread_id":"..."}). +// - resumeMessage is written to the file passed via -o for the Search +// call (codex's --output-last-message contract). +func writeFakeCodex(t *testing.T, indexStdout, resumeMessage string) { + t.Helper() + + dir := t.TempDir() + indexFile := filepath.Join(dir, "index.out") + require.NoError(t, os.WriteFile(indexFile, []byte(indexStdout+"\n"), 0o644)) + resumeFile := filepath.Join(dir, "resume.out") + require.NoError(t, os.WriteFile(resumeFile, []byte(resumeMessage), 0o644)) + + name := "codex" + if runtime.GOOS == "windows" { + name += ".bat" + } + path := filepath.Join(dir, name) + var script string + if runtime.GOOS == "windows" { + // Minimal Windows fallback — only Index path is exercised in CI on Unix. 
+ script = "@echo off\r\ntype \"" + indexFile + "\"\r\n" + } else { + script = "#!/bin/sh\n" + + "is_resume=0\n" + + "out_file=\"\"\n" + + "prev=\"\"\n" + + "for arg in \"$@\"; do\n" + + " if [ \"$prev\" = \"-o\" ]; then out_file=\"$arg\"; fi\n" + + " if [ \"$arg\" = \"resume\" ]; then is_resume=1; fi\n" + + " prev=\"$arg\"\n" + + "done\n" + + "cat >/dev/null\n" + + "if [ \"$is_resume\" = \"1\" ]; then\n" + + " if [ -n \"$out_file\" ]; then cat \"" + resumeFile + "\" > \"$out_file\"; fi\n" + + "else\n" + + " cat \"" + indexFile + "\"\n" + + "fi\n" + } + require.NoError(t, os.WriteFile(path, []byte(script), 0o755)) + t.Setenv("PATH", dir+string(os.PathListSeparator)+os.Getenv("PATH")) +} + +var _ http.Handler = (*http.ServeMux)(nil) + +func normalizeSearchResponse(t *testing.T, resp *yokov1.SearchResponse) *yokov1.SearchResponse { + t.Helper() + + data, err := proto.Marshal(resp) + require.NoError(t, err) + normalized := &yokov1.SearchResponse{} + require.NoError(t, proto.Unmarshal(data, normalized)) + return normalized +} + +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} diff --git a/demo/code-mode/yoko-mock/schema.graphql b/demo/code-mode/yoko-mock/schema.graphql new file mode 100644 index 0000000000..ed60f8b07b --- /dev/null +++ b/demo/code-mode/yoko-mock/schema.graphql @@ -0,0 +1,825 @@ +schema { + query: Query +} + +""" +Pylon — customer support tickets, accounts, contacts, surveys. +All types vendor-prefixed with `pylon_` to keep the federated schema collision-free. +""" +scalar pylon_DateTime + +scalar pylon_JSON + +""" +Slack — read shared customer channels, message history, user info. +All types vendor-prefixed with `slack_`. +""" +type Query { + pylon_searchAccounts(input: pylon_SearchAccountsInput!): [pylon_Account!]! + pylon_getAccount(id: ID!): pylon_Account + pylon_searchIssues(input: pylon_SearchIssuesInput!): [pylon_Issue!]! 
+ pylon_getIssue(id: ID!): pylon_Issue + pylon_listIssueMessages(issueId: ID!): [pylon_Message!]! + pylon_listContacts(input: pylon_ListContactsInput!): [pylon_Contact!]! + pylon_listSurveys(limit: Int, cursor: String): [pylon_Survey!]! + pylon_getSurveyResponses(surveyId: ID!): [pylon_SurveyResponse!]! + """ + Canonical custom-field slugs: arr, renewal_date, segment, csm_owner, health_status. + """ + pylon_listAccountCustomFields: [pylon_CustomField!]! + pylon_listUsers: [pylon_User!]! + linear_customer(id: ID!): linear_Customer + linear_customerByDomain(domain: String!): linear_Customer + linear_customers(filter: linear_CustomersFilter): [linear_Customer!]! + linear_customerNeeds(filter: linear_CustomerNeedsFilter): [linear_CustomerNeed!]! + linear_issue(id: ID!): linear_Issue + linear_issues(filter: linear_IssuesFilter): [linear_Issue!]! + linear_team(id: ID!): linear_Team + linear_teams: [linear_Team!]! + linear_cycles(teamId: ID): [linear_Cycle!]! + linear_project(id: ID!): linear_Project + linear_user(id: ID!): linear_User + """ + Run a HogQL query. Two modes are supported: + + 1. **Preset (recommended)** — set `input.preset` to one of the values + of `posthog_HogQLPreset`. These map 1:1 to the curated mock handlers + and are guaranteed to execute. Pass any required parameters via + `input.presetParams` (see each enum value's docs for what it needs). + + 2. **Freeform** — set `input.hogql` to a raw HogQL string. Only the exact + shapes recognised by the preset matchers will succeed; any other query + returns Unimplemented with the supported handler list. + + If both `preset` and `hogql` are set, `preset` takes precedence. + """ + posthog_query(input: posthog_QueryInput!): posthog_QueryResult! + """List groups (customer orgs) for a given group type index.""" + posthog_listGroups(input: posthog_ListGroupsInput!): [posthog_Group!]! 
+ posthog_getGroup(typeIndex: Int!, key: String!): posthog_Group + posthog_listEvents(input: posthog_ListEventsInput!): posthog_ListEventsResult! + """ + List feature flags configured for the project (used for feature-adoption mapping). + """ + posthog_listFeatureFlags(limit: Int): [posthog_FeatureFlag!]! + """Poll an async query result by id.""" + posthog_getAsyncQueryStatus(queryId: String!): posthog_AsyncQueryStatus! + circleback_searchMeetings(input: circleback_SearchMeetingsInput!): [circleback_Meeting!]! + circleback_readMeetings(meetingIds: [ID!]!): [circleback_Meeting!]! + circleback_getTranscripts(meetingIds: [ID!]!): [circleback_Transcript!]! + circleback_searchTranscripts(query: String!, limit: Int): [circleback_TranscriptHit!]! + circleback_searchActionItems(query: String!, limit: Int): [circleback_ActionItem!]! + circleback_findDomains(query: String!): [circleback_Domain!]! + circleback_findProfiles(query: String!): [circleback_Profile!]! + circleback_searchCalendarEvents(input: circleback_SearchCalendarInput!): [circleback_CalendarEvent!]! + circleback_listTags: [String!]! + slack_listChannels(input: slack_ListChannelsInput!): [slack_Channel!]! + slack_getChannel(channelId: ID!): slack_Channel + slack_history(input: slack_HistoryInput!): slack_HistoryResult! + slack_replies(channelId: ID!, threadTs: String!, limit: Int): [slack_Message!]! + slack_userInfo(userId: ID!): slack_User + slack_listUsers(limit: Int, cursor: String): [slack_User!]! + slack_searchMessages(query: String!, count: Int, page: Int): [slack_Message!]! + slack_authTest: slack_AuthTestResult! + notion_search(input: notion_SearchInput!): [notion_SearchResult!]! + notion_getPage(pageId: ID!): notion_Page + notion_getDatabase(databaseId: ID!): notion_Database + notion_queryDataSource(input: notion_QueryDataSourceInput!): notion_QueryDataSourceResult! + notion_getBlockChildren(blockId: ID!, limit: Int, cursor: String): notion_BlockChildrenResult! 
+ notion_listUsers(limit: Int, cursor: String): [notion_User!]! +} + +type pylon_Account { + id: ID! + name: String! + domains: [String!]! + tags: [String!]! + customFields: [pylon_CustomFieldValue!]! +} + +type pylon_CustomField { + slug: String! + label: String! + objectType: String! + type: String! +} + +type pylon_CustomFieldValue { + slug: String! + label: String! + value: String! +} + +type pylon_Issue { + id: ID! + title: String! + state: String! + accountId: ID + priority: pylon_IssuePriority! + number: Int! + assignee: pylon_User + requester: pylon_Contact + tags: [String!]! + latestMessageAt: pylon_DateTime + slaBreached: Boolean! + createdAt: pylon_DateTime! + resolvedAt: pylon_DateTime + firstResponseSeconds: Int + resolutionSeconds: Int + resolutionBreachTime: pylon_DateTime + numberOfTouches: Int + externalIssues: [pylon_ExternalIssueLink!]! + csatResponses: [pylon_SurveyResponse!]! +} + +enum pylon_IssuePriority { + P1 + P2 + P3 + P4 +} + +type pylon_ExternalIssueLink { + source: String! + externalId: String! + url: String +} + +type pylon_Message { + id: ID! + issueId: ID! + authorId: ID + body: String! + createdAt: pylon_DateTime! +} + +type pylon_Contact { + id: ID! + email: String + name: String + accountId: ID +} + +type pylon_User { + id: ID! + name: String! + email: String! +} + +type pylon_Survey { + id: ID! + type: pylon_SurveyType! + name: String! +} + +enum pylon_SurveyType { + CSAT + NPS + CES +} + +type pylon_SurveyResponse { + id: ID! + surveyId: ID! + accountId: ID + contactId: ID + score: Int! + comment: String + createdAt: pylon_DateTime! +} + +input pylon_SearchAccountsInput { + name: String + domain: String + tag: String + limit: Int + cursor: String +} + +input pylon_SearchIssuesInput { + accountId: ID + state: String + createdAfter: pylon_DateTime + createdBefore: pylon_DateTime + resolvedAfter: pylon_DateTime + resolvedBefore: pylon_DateTime + priority: pylon_IssuePriority + tags: [String!] 
+ slaBreached: Boolean + limit: Int +} + +input pylon_ListContactsInput { + accountId: ID + email: String + limit: Int + cursor: String +} + +""" +Linear — engineering issues, projects, and the native Customer entity. +All types vendor-prefixed with `linear_`. +""" +scalar linear_DateTime + +type linear_Customer { + id: ID! + name: String! + domains: [String!]! + externalIds: [String!]! + revenue: Float + size: Int + ownerId: ID + slackChannelId: String +} + +type linear_CustomerNeed { + id: ID! + customerId: ID! + issueId: ID + projectId: ID + important: Boolean! + body: String + createdAt: linear_DateTime! +} + +type linear_Issue { + id: ID! + identifier: String! + title: String! + description: String + priority: Int! + priorityLabel: String! + state: linear_IssueState! + needs: [linear_CustomerNeed!]! + customerTicketCount: Int! + teamId: ID! + assigneeId: ID + cycleId: ID + projectId: ID + labels: [String!]! + url: String! + createdAt: linear_DateTime! + updatedAt: linear_DateTime! + completedAt: linear_DateTime + addedToCycleAt: linear_DateTime +} + +enum linear_IssueState { + TRIAGE + BACKLOG + UNSTARTED + STARTED + COMPLETED + CANCELED + DUPLICATE +} + +type linear_Project { + id: ID! + name: String! + description: String + state: String! + health: linear_ProjectHealth! + progress: Float! + leadId: ID + teamId: ID! + startDate: linear_DateTime + targetDate: linear_DateTime + url: String! +} + +enum linear_ProjectHealth { + ON_TRACK + AT_RISK + OFF_TRACK +} + +type linear_Cycle { + id: ID! + teamId: ID! + number: Int! + name: String + startsAt: linear_DateTime! + endsAt: linear_DateTime! + progress: Float! +} + +type linear_Team { + id: ID! + key: String! + name: String! +} + +type linear_User { + id: ID! + name: String! + email: String! + active: Boolean! 
+} + +input linear_CustomersFilter { + domain: String + externalId: String + search: String + limit: Int +} + +input linear_CustomerNeedsFilter { + customerId: ID + createdAfter: linear_DateTime + createdBefore: linear_DateTime + important: Boolean + limit: Int +} + +input linear_IssuesFilter { + customerId: ID + teamId: ID + cycleId: ID + priority: Int + state: linear_IssueState + createdAfter: linear_DateTime + createdBefore: linear_DateTime + updatedAfter: linear_DateTime + limit: Int +} + +""" +PostHog — product telemetry queryable via HogQL. +Customers/orgs are modeled as PostHog `groups` (typeIndex 0..4). +All types vendor-prefixed with `posthog_`. + +Mock auth requires both X-Posthog-Token and X-Posthog-Project-Id metadata. +""" +scalar posthog_JSON + +type posthog_QueryResult { + columns: [String!]! + types: [String!]! + rows: [posthog_JSON!]! + hasMore: Boolean! + queryId: String + asyncStatus: posthog_AsyncQueryStatus +} + +type posthog_AsyncQueryStatus { + queryId: String! + state: posthog_AsyncQueryState! + errorMessage: String +} + +enum posthog_AsyncQueryState { + PENDING + RUNNING + COMPLETED + ERROR +} + +input posthog_QueryInput { + """ + Pre-defined query to execute. Recommended path — runs a curated mock + handler and is guaranteed to succeed. Takes precedence over `hogql` + when both are set. + """ + preset: posthog_HogQLPreset + """ + Parameters consumed by the chosen `preset`. See each enum value's docs + for which fields it requires. + """ + presetParams: posthog_HogQLPresetParams + """ + Freeform HogQL string. Only the exact shapes recognised by the preset + matchers will succeed; any other query returns Unimplemented. + Prefer `preset` for new callers. Omit when `preset` is set. + """ + hogql: String + refresh: posthog_RefreshMode + filtersOverride: posthog_JSON +} + +""" +Pre-defined HogQL queries available in the mock. Each value runs a +curated handler over the seeded event data and is guaranteed to execute. 
+""" +enum posthog_HogQLPreset { + """ + Quarter-over-quarter event count for one company. + Required params: `domain` (e.g. "ebay.com"). + Returns rows of (quarter DateTime, events UInt64). + """ + QOQ_COMPANY + """ + Daily-active-users + per-day event count for one company over the + last 30 days. + Required params: `domain`. + Returns rows of (day Date, dau UInt64, events UInt64). + """ + DAU_TIMESERIES_COMPANY + """ + Accounts whose 30-day event volume dropped >20% versus the prior 30 + days, sorted ascending by pct_change (most-at-risk first). + Required params: none. + Returns rows of (key, recent, prior, delta, pct_change). + """ + AT_RISK_ACCOUNTS + """ + Feature-adoption matrix across the top-10 customers (one row per + customer × feature_slug used). + Required params: none. + Returns rows of (key, feature_slug, uses). + """ + FEATURE_ADOPTION_TOP10 + """ + P95 request latency bucketed hourly for one company in a time window. + Required params: `domain`, `start`, `end` (RFC3339, e.g. + "2026-04-21T14:00:00Z"). + Returns rows of (hour DateTime, p95_ms Float64). + """ + LATENCY_HOURLY_COMPANY + """ + Per-event-name count for one company in a time window. + Required params: `domain`, `start` (RFC3339); optional: `end` (RFC3339). + Returns rows of (event String, count UInt64). + """ + EVENT_BREAKDOWN_WINDOW + """ + Week-over-week event delta across the entire portfolio (per group_0), + sorted by delta descending. + Required params: none. + Returns rows of (key, this_count, prev_count, delta). + """ + WEEKLY_PORTFOLIO_DELTA +} + +""" +Parameters for a `posthog_HogQLPreset` query. Only the fields required +by the chosen preset need to be set; extras are ignored. +""" +input posthog_HogQLPresetParams { + """Group key, typically a customer domain like "ebay.com".""" + domain: String + """ + RFC3339 start timestamp with time component, e.g. "2026-04-21T14:00:00Z". + """ + start: String + """ + RFC3339 end timestamp with time component, e.g. "2026-04-21T15:00:00Z". 
+ """ + end: String +} + +enum posthog_RefreshMode { + BLOCKING + ASYNC + LAZY + FORCE_BLOCKING + FORCE_ASYNC +} + +type posthog_Group { + typeIndex: Int! + key: String! + properties: posthog_JSON! + createdAt: String +} + +input posthog_ListGroupsInput { + typeIndex: Int! + search: String + limit: Int +} + +type posthog_Event { + timestamp: String! + distinctId: String! + event: String! + group0: String! + properties: posthog_JSON! +} + +type posthog_ListEventsResult { + events: [posthog_Event!]! + nextCursor: String + hasMore: Boolean! +} + +input posthog_ListEventsInput { + groupKey: String + eventName: String + startTime: String + endTime: String + limit: Int + cursor: String +} + +type posthog_FeatureFlag { + id: ID! + key: String! + name: String + active: Boolean! + filters: posthog_JSON! + topRolloutPercentage: Float +} + +""" +Circleback — meeting transcripts, summaries, action items. +This subgraph serves deterministic embedded Circleback-style mock fixtures. +All types vendor-prefixed with `circleback_`. +""" +scalar circleback_DateTime + +type circleback_Meeting { + id: ID! + name: String! + createdAt: circleback_DateTime! + duration: Int! + url: String + recordingUrl: String + tags: [String!]! + attendees: [circleback_Attendee!]! + notes: String + actionItems: [circleback_ActionItem!]! + icalUid: String + organizerEmail: String! +} + +type circleback_Attendee { + email: String! + name: String +} + +enum circleback_ActionItemStatus { + PENDING + DONE +} + +type circleback_ActionItem { + id: ID! + meetingId: ID! + meetingName: String + title: String! + description: String + status: circleback_ActionItemStatus! + assignee: circleback_Attendee +} + +type circleback_Transcript { + meetingId: ID! + segments: [circleback_TranscriptSegment!]! +} + +type circleback_TranscriptSegment { + speaker: String! + startMs: Int! + endMs: Int! + text: String! +} + +type circleback_TranscriptHit { + meetingId: ID! + speaker: String! + text: String! + startMs: Int! 
+ score: Float! +} + +type circleback_Domain { + domain: String! + meetingCount: Int! +} + +type circleback_Profile { + email: String! + name: String + domain: String + meetingCount: Int! +} + +type circleback_CalendarEvent { + id: ID! + title: String! + startsAt: circleback_DateTime! + endsAt: circleback_DateTime! + attendees: [circleback_Attendee!]! + icalUid: String! + tags: [String!]! + notes: String + actionItems: [circleback_ActionItem!]! + organizerEmail: String! +} + +input circleback_SearchMeetingsInput { + attendeeEmail: String + attendeeDomain: String + tag: String + keyword: String + startDate: circleback_DateTime + endDate: circleback_DateTime + limit: Int +} + +input circleback_SearchCalendarInput { + query: String + startDate: circleback_DateTime + endDate: circleback_DateTime + limit: Int +} + +type slack_Channel { + id: ID! + name: String! + isPrivate: Boolean! + isArchived: Boolean! + isMember: Boolean! + isShared: Boolean! + isExtShared: Boolean! + created: Int! + creatorId: ID! + topic: String + purpose: String + numMembers: Int +} + +type slack_Message { + ts: String! + channelId: ID! + userId: ID + text: String! + threadTs: String + permalink: String! + username: String + subtype: String + editedTs: String + replyCount: Int + replyUsersCount: Int + latestReplyTs: String + reactions: [slack_Reaction!]! +} + +type slack_Reaction { + name: String! + count: Int! + userIds: [ID!]! +} + +type slack_HistoryResult { + messages: [slack_Message!]! + hasMore: Boolean! + nextCursor: String +} + +type slack_User { + id: ID! + name: String! + realName: String + email: String + title: String + displayName: String + image: String + tz: String + isBot: Boolean! + deleted: Boolean! +} + +type slack_AuthTestResult { + ok: Boolean! + url: String! + team: String! + user: String! + teamId: ID! + userId: ID! + botId: ID +} + +input slack_ListChannelsInput { + types: [slack_ChannelType!] 
+ excludeArchived: Boolean + namePrefix: String + limit: Int +} + +enum slack_ChannelType { + PUBLIC_CHANNEL + PRIVATE_CHANNEL + MPIM + IM +} + +input slack_HistoryInput { + channelId: ID! + oldest: String + latest: String + limit: Int + cursor: String +} + +""" +Notion — internal knowledge base / feature documentation. +All types vendor-prefixed with `notion_`. +""" +scalar notion_JSON + +scalar notion_DateTime + +type notion_Page { + id: ID! + """ + Synthesized convenience field. Notion stores page titles in the title property + (usually properties["Name"].title[0].plain_text), and the resolver derives this value. + """ + title: String! + url: String! + parentId: ID + parentType: notion_ParentType + createdAt: notion_DateTime! + updatedAt: notion_DateTime! + archived: Boolean! + properties: notion_JSON! +} + +type notion_Database { + id: ID! + title: String! + url: String! + createdAt: notion_DateTime! + updatedAt: notion_DateTime! + dataSources: [notion_DataSource!]! +} + +type notion_DataSource { + id: ID! + databaseId: ID! + name: String +} + +type notion_Block { + id: ID! + type: String! + hasChildren: Boolean! + archived: Boolean! + content: notion_JSON! +} + +type notion_BlockChildrenResult { + blocks: [notion_Block!]! + hasMore: Boolean! + nextCursor: String +} + +type notion_QueryDataSourceResult { + pages: [notion_Page!]! + hasMore: Boolean! + nextCursor: String +} + +type notion_SearchResult { + objectType: notion_ObjectType! + id: ID! + title: String! + url: String! +} + +enum notion_ObjectType { + PAGE + DATABASE + DATA_SOURCE +} + +enum notion_ParentType { + PAGE + DATABASE + DATA_SOURCE + WORKSPACE + BLOCK +} + +type notion_User { + id: ID! + name: String! + email: String + type: notion_UserType! +} + +enum notion_UserType { + PERSON + BOT +} + +input notion_SearchInput { + query: String + filterType: notion_ObjectType + limit: Int +} + +input notion_QueryDataSourceInput { + dataSourceId: ID! + """ + Best-effort Notion-style property equals filter. 
The mock supports Status, + Segment, Health, Slug, and Domain; unsupported filter JSON is ignored. + """ + filter: notion_JSON + sorts: notion_JSON + limit: Int + cursor: String +} \ No newline at end of file diff --git a/demo/pkg/subgraphs/employees/subgraph/schema.resolvers.go b/demo/pkg/subgraphs/employees/subgraph/schema.resolvers.go index 3f0d30d16b..ecbf4d61f4 100644 --- a/demo/pkg/subgraphs/employees/subgraph/schema.resolvers.go +++ b/demo/pkg/subgraphs/employees/subgraph/schema.resolvers.go @@ -61,25 +61,9 @@ func (r *mutationResolver) UpdateEmployeeTag(ctx context.Context, id int, tag st defer r.mux.Unlock() for _, employee := range r.EmployeesData { if id == employee.ID { - details := &model.Details{} - if employee.Details != nil { - details.Forename = employee.Details.Forename - details.Surname = employee.Details.Surname - details.Location = employee.Details.Location - } - return &model.Employee{ - ID: employee.ID, - Details: details, - Tag: tag, - Expertise: employee.Expertise, - Role: employee.Role, - Notes: employee.Notes, - UpdatedAt: time.Now().String(), - StartDate: employee.StartDate, - PrimaryWorkItem: employee.PrimaryWorkItem, - LastWorkReview: employee.LastWorkReview, - WorkSetup: employee.WorkSetup, - }, nil + employee.Tag = tag + employee.UpdatedAt = time.Now().String() + return employee, nil } } return nil, nil diff --git a/proto/wg/cosmo/code_mode/yoko/v1/yoko.proto b/proto/wg/cosmo/code_mode/yoko/v1/yoko.proto new file mode 100644 index 0000000000..3add37193b --- /dev/null +++ b/proto/wg/cosmo/code_mode/yoko/v1/yoko.proto @@ -0,0 +1,84 @@ +syntax = "proto3"; + +package wundergraph.cosmo.code_mode.yoko.v1; + +// Yoko generates GraphQL operations for natural-language prompts. +// +// Two-step flow: +// 1. Index(schema_sdl) -> schema_id. Idempotent: the same SDL +// always returns the same schema_id for as long as the index +// is retained. 
The router caches schema_id and re-indexes only +// on supergraph reload (or when Search returns NOT_FOUND). +// 2. Search(prompts, schema_id) -> operations. Yoko owns prompt +// fan-out, partial-failure handling, cross-prompt deduplication, +// and ranking. +// +// The router never sends a schema hash. Yoko is the sole authority +// on schema identity; the router only sends raw SDL on Index and an +// opaque id on Search. +service YokoService { + rpc Index(IndexRequest) returns (IndexResponse); + rpc Search(SearchRequest) returns (SearchResponse); +} + +message IndexRequest { + // The supergraph SDL to index. Sent in full on every Index call; + // Yoko deduplicates internally and is free to short-circuit when + // the SDL is already known. + string schema_sdl = 1; +} + +message IndexResponse { + // Opaque, Yoko-assigned identifier for this schema. Stable for as + // long as Yoko retains the index. Subsequent Search calls pass this + // back instead of the full SDL. Idempotent: the same SDL returns + // the same schema_id. + string schema_id = 1; +} + +message SearchRequest { + // Batch of natural-language prompts. Bounded at 20 by the host. + repeated string prompts = 1; + + // Identifier returned by a prior Index call. If Yoko no longer + // recognizes the id (e.g. eviction, restart), it MUST return the + // Connect error code NOT_FOUND; the router re-indexes and retries + // the call exactly once. + string schema_id = 2; + + // Opaque MCP session ID for telemetry correlation only. + // Yoko MUST NOT use this for stateful behavior — sessions are owned + // by the router. + string session_id = 3; +} + +message SearchResponse { + // Operations across all prompts, already deduplicated and ranked. + // Order is significant: earlier entries rank higher and are preferred + // when bundle truncation drops from the tail. + repeated GeneratedOperation operations = 1; +} + +message GeneratedOperation { + // Suggested operation name (camelCase preferred). 
The host applies + // its own identifier normalization and in-session collision-suffix + // logic on top of this — see §6. + string name = 1; + + // GraphQL operation body (query or mutation source text). + string body = 2; + + // Operation kind. Subscriptions are out of scope; if Yoko returns + // one, the host drops it with a single warn log. + OperationKind kind = 3; + + // Human-readable description, surfaced as JSDoc on the typed + // `tools.` signature in the rendered bundle. + string description = 4; +} + +enum OperationKind { + OPERATION_KIND_UNSPECIFIED = 0; + OPERATION_KIND_QUERY = 1; + OPERATION_KIND_MUTATION = 2; +} diff --git a/router-tests/code_mode_named_ops_test.go b/router-tests/code_mode_named_ops_test.go new file mode 100644 index 0000000000..40ea76a5fb --- /dev/null +++ b/router-tests/code_mode_named_ops_test.go @@ -0,0 +1,621 @@ +package integration + +import ( + "context" + "encoding/json" + "fmt" + "net" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "connectrpc.com/connect" + miniredis "github.com/alicebob/miniredis/v2" + mark3mcp "github.com/mark3labs/mcp-go/mcp" + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap/zapcore" + + "github.com/wundergraph/cosmo/router-tests/freeport" + "github.com/wundergraph/cosmo/router-tests/testenv" + yokov1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/code_mode/yoko/v1" + yokoconnect "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/code_mode/yoko/v1/yokov1connect" + nodev1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/node/v1" + "github.com/wundergraph/cosmo/router/pkg/config" + "github.com/wundergraph/cosmo/router/pkg/controlplane/configpoller" + "github.com/wundergraph/cosmo/router/pkg/routerconfig" +) + +const codeModePersistedOpsURI = "yoko://persisted-ops.d.ts" + +const ( + firstEmployeeOpName = "firstEmployee" + employeeByIDOpName = "employeeByID" + 
updateTagOpName = "updateEmployeeTag" + + firstEmployeeQuery = `query firstEmployee { firstEmployee { id details { forename surname } } }` + employeeByIDQuery = `query employeeByID($id: Int!) { employee(id: $id) { id details { forename surname } } }` + updateTagMutation = `mutation updateEmployeeTag($id: Int!, $tag: String!) { updateEmployeeTag(id: $id, tag: $tag) { id tag } }` +) + +const firstEmployeeTS = `/** Fetch the first employee. */ +firstEmployee(): R<{ firstEmployee: { id: number; details: { forename: string; surname: string } | null } }>;` + +const employeeByIDTS = `/** Fetch employee by id. */ +employeeByID(vars: { id: number }): R<{ employee: { id: number; details: { forename: string; surname: string } | null } | null }>;` + +const updateTagTS = `/** Update employee tag. */ +updateEmployeeTag(vars: { id: number; tag: string }): R<{ updateEmployeeTag: { id: number; tag: string } | null }>;` + +const twoOpsFragment = firstEmployeeTS + "\n\n" + employeeByIDTS + +// indentBundleEntry mirrors tsgen's behavior: every line of a per-op block +// (JSDoc + signature) is indented by 2 spaces inside the tools object. +func indentBundleEntry(s string) string { + return " " + strings.ReplaceAll(s, "\n", "\n ") +} + +const emptyOpsBundle = `type GraphQLError = { message: string; path?: (string | number)[]; extensions?: Record }; +type R = Promise<{ data: T | null; errors?: GraphQLError[] }>; +// Known limitation: union and interface selections are typed as unknown. + +declare const tools: {}; + +declare function notNull(value: T | null | undefined, message?: string): T; +declare function compact(value: T): T;` + +var firstEmployeeBundle = `type GraphQLError = { message: string; path?: (string | number)[]; extensions?: Record }; +type R = Promise<{ data: T | null; errors?: GraphQLError[] }>; +// Known limitation: union and interface selections are typed as unknown. 
+
+declare const tools: {
+` + indentBundleEntry(firstEmployeeTS) + `
+};
+
+declare function notNull<T>(value: T | null | undefined, message?: string): T;
+declare function compact<T>(value: T): T;`
+
+// employeeByIDBundle is the expected bundle when only employeeByID is
+// registered in the session catalog.
+//
+// NOTE(review): generic parameters (<T>, Record<string, unknown>) restored
+// after markup sanitization in the reviewed copy — confirm against tsgen.
+var employeeByIDBundle = `type GraphQLError = { message: string; path?: (string | number)[]; extensions?: Record<string, unknown> };
+type R<T> = Promise<{ data: T | null; errors?: GraphQLError[] }>;
+// Known limitation: union and interface selections are typed as unknown.
+
+declare const tools: {
+` + indentBundleEntry(employeeByIDTS) + `
+};
+
+declare function notNull<T>(value: T | null | undefined, message?: string): T;
+declare function compact<T>(value: T): T;`
+
+// twoOpsBundle is the expected bundle with both query operations
+// registered, in registration order.
+var twoOpsBundle = `type GraphQLError = { message: string; path?: (string | number)[]; extensions?: Record<string, unknown> };
+type R<T> = Promise<{ data: T | null; errors?: GraphQLError[] }>;
+// Known limitation: union and interface selections are typed as unknown.
+
+declare const tools: {
+` + indentBundleEntry(firstEmployeeTS) + `
+
+` + indentBundleEntry(employeeByIDTS) + `
+};
+
+declare function notNull<T>(value: T | null | undefined, message?: string): T;
+declare function compact<T>(value: T): T;`
+
+// codeModeBackend selects the named-ops storage backend for a test run:
+// the in-memory default, or Redis via a provider ID and connection URL.
+type codeModeBackend struct {
+	name       string
+	providerID string
+	redisURL   string
+}
+
+// Happy path on the memory backend: search registers two ops and returns
+// their TS signatures, the persisted-ops resource serves the cumulative
+// bundle, and code_mode_run_js executes a registered op end to end.
+func TestCodeModeNamedOpsMemoryBackendStatefulSearchExecuteAndResource(t *testing.T) {
+	withCodeModeNamedOps(t, codeModeBackend{name: "memory"}, codeModeNamedOpsOptions{}, func(ctx context.Context, _ string, xEnv *testenv.Environment, yoko *fakeCodeModeYoko, session *mcp.ClientSession) {
+		searchText := callCodeModeToolText(t, ctx, session, "code_mode_search_tools", map[string]any{
+			"prompts": []string{"first employee", "employee by id"},
+		})
+		assert.Equal(t, twoOpsFragment, searchText)
+		// Deliberately self-referential on SchemaSdl/SessionId: these two
+		// asserts pin that exactly one Index and one Search call happened
+		// with the expected prompts and the schema id minted by the fake.
+		assert.Equal(t, []*yokov1.IndexRequest{{SchemaSdl: yoko.indexRequests()[0].GetSchemaSdl()}}, yoko.indexRequests())
+		assert.Equal(t, []*yokov1.SearchRequest{{
+			Prompts:   []string{"first employee", "employee by id"},
+			SchemaId:  "schema-1",
+			SessionId: 
yoko.searchRequests()[0].GetSessionId(), + }}, yoko.searchRequests()) + + resource := readPersistedOpsResource(t, ctx, session) + assert.Equal(t, &mcp.ReadResourceResult{Contents: []*mcp.ResourceContents{{ + URI: codeModePersistedOpsURI, + MIMEType: "text/plain", + Text: twoOpsBundle, + }}}, resource) + + executeText := callCodeModeToolText(t, ctx, session, "code_mode_run_js", map[string]any{ + "source": `async () => { return await tools.employeeByID({ id: 1 }); }`, + }) + assert.Equal(t, map[string]any{ + "result": map[string]any{ + "data": map[string]any{ + "employee": map[string]any{ + "id": float64(1), + "details": map[string]any{ + "forename": "Jens", + "surname": "Neuse", + }, + }, + }, + }, + }, decodeJSON(t, executeText)) + + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{Query: `{ employee(id: 1) { id details { forename surname } } }`}) + assert.Equal(t, `{"data":{"employee":{"id":1,"details":{"forename":"Jens","surname":"Neuse"}}}}`, res.Body) + }) +} + +func TestCodeModeNamedOpsConcurrentSessions(t *testing.T) { + withCodeModeNamedOps(t, codeModeBackend{name: "memory"}, codeModeNamedOpsOptions{}, func(ctx context.Context, endpoint string, _ *testenv.Environment, _ *fakeCodeModeYoko, sessionA *mcp.ClientSession) { + searchText := callCodeModeToolText(t, ctx, sessionA, "code_mode_search_tools", map[string]any{ + "prompts": []string{"first employee"}, + }) + assert.Equal(t, firstEmployeeTS, searchText) + + sessionB := newCodeModeMCPClient(t, ctx, endpoint, nil) + resourceA := readPersistedOpsResource(t, ctx, sessionA) + resourceB := readPersistedOpsResource(t, ctx, sessionB) + + assert.Equal(t, &mcp.ReadResourceResult{Contents: []*mcp.ResourceContents{{ + URI: codeModePersistedOpsURI, + MIMEType: "text/plain", + Text: firstEmployeeBundle, + }}}, resourceA) + assert.Equal(t, &mcp.ReadResourceResult{Contents: []*mcp.ResourceContents{{ + URI: codeModePersistedOpsURI, + MIMEType: "text/plain", + Text: emptyOpsBundle, + }}}, resourceB) + }) +} + +func 
TestCodeModeNamedOpsSchemaReloadEvictsSession(t *testing.T) { + poller := &codeModeConfigPoller{ready: make(chan struct{})} + withCodeModeNamedOps(t, codeModeBackend{name: "memory"}, codeModeNamedOpsOptions{poller: poller}, func(ctx context.Context, _ string, _ *testenv.Environment, _ *fakeCodeModeYoko, session *mcp.ClientSession) { + searchText := callCodeModeToolText(t, ctx, session, "code_mode_search_tools", map[string]any{ + "prompts": []string{"employee by id"}, + }) + assert.Equal(t, employeeByIDTS, searchText) + + <-poller.ready + poller.initConfig.Version = "code-mode-reload" + require.NoError(t, poller.updateConfig(poller.initConfig, "before-code-mode-reload")) + + executeText := callCodeModeToolText(t, ctx, session, "code_mode_run_js", map[string]any{ + "source": `async () => { return await tools.employeeByID({ id: 1 }); }`, + }) + assert.Equal(t, map[string]any{ + "result": nil, + "error": map[string]any{ + "name": "TypeError", + "message": "tools.employeeByID is not a function", + "stack": " at __agentMain (codemode_agent.js:agent.ts:1:34)\n at (codemode_agent.js:73:42)\n at (codemode_agent.js:77:1)\n", + }, + }, decodeJSON(t, executeText)) + }) +} + +func TestCodeModeNamedOpsMutationElicitationRejection(t *testing.T) { + decline := func(context.Context, *mcp.ElicitRequest) (*mcp.ElicitResult, error) { + return &mcp.ElicitResult{Action: "accept", Content: map[string]any{ + "approved": false, + "reason": "policy forbids", + }}, nil + } + withCodeModeNamedOps(t, codeModeBackend{name: "memory"}, codeModeNamedOpsOptions{elicitationHandler: decline}, func(ctx context.Context, _ string, _ *testenv.Environment, _ *fakeCodeModeYoko, session *mcp.ClientSession) { + searchText := callCodeModeToolText(t, ctx, session, "code_mode_search_tools", map[string]any{ + "prompts": []string{"update employee tag"}, + }) + assert.Equal(t, updateTagTS, searchText) + + executeText := callCodeModeToolText(t, ctx, session, "code_mode_run_js", map[string]any{ + "source": `async () 
=> { return await tools.updateEmployeeTag({ id: 1, tag: "x" }); }`, + }) + assert.Equal(t, map[string]any{ + "result": map[string]any{ + "data": nil, + "declined": map[string]any{ + "reason": "policy forbids", + }, + "errors": []any{ + map[string]any{"message": "Mutation declined by operator: policy forbids"}, + }, + }, + }, decodeJSON(t, executeText)) + }) +} + +func TestCodeModeNamedOpsTranspileError(t *testing.T) { + withCodeModeNamedOps(t, codeModeBackend{name: "memory"}, codeModeNamedOpsOptions{}, func(ctx context.Context, _ string, _ *testenv.Environment, _ *fakeCodeModeYoko, session *mcp.ClientSession) { + searchText := callCodeModeToolText(t, ctx, session, "code_mode_search_tools", map[string]any{ + "prompts": []string{"employee by id"}, + }) + assert.Equal(t, employeeByIDTS, searchText) + + executeText := callCodeModeToolText(t, ctx, session, "code_mode_run_js", map[string]any{ + "source": `async () => { let x = ; }`, + }) + assert.Equal(t, map[string]any{ + "result": nil, + "error": map[string]any{ + "name": "TranspileError", + "message": "transpile failed: Unexpected \";\"", + "stack": "", + }, + }, decodeJSON(t, executeText)) + }) +} + +func TestCodeModeNamedOpsListResourcesGating(t *testing.T) { + t.Run("code mode disabled does not advertise persisted ops on main MCP server", func(t *testing.T) { + yoko := newFakeCodeModeYoko() + yokoServer := startFakeCodeModeYoko(t, yoko) + cfg := baseCodeModeTestConfig(t, yokoServer.URL, codeModeBackend{name: "memory"}, codeModeNamedOpsOptions{}) + cfg.MCP.CodeMode.Enabled = false + + testenv.Run(t, cfg, func(t *testing.T, xEnv *testenv.Environment) { + resources, err := xEnv.MCPClient.ListResources(ctxWithTimeout(t), mark3mcp.ListResourcesRequest{}) + require.NoError(t, err) + assert.Equal(t, false, mark3ResourcesContain(resources.Resources, codeModePersistedOpsURI)) + }) + }) + + t.Run("named ops disabled does not advertise persisted ops", func(t *testing.T) { + withCodeModeNamedOps(t, codeModeBackend{name: 
"memory"}, codeModeNamedOpsOptions{namedOpsEnabled: boolPtr(false)}, func(ctx context.Context, _ string, _ *testenv.Environment, _ *fakeCodeModeYoko, session *mcp.ClientSession) { + resources, err := session.ListResources(ctx, &mcp.ListResourcesParams{}) + require.NoError(t, err) + assert.Equal(t, []*mcp.Resource{}, resources.Resources) + }) + }) + + t.Run("stateless does not advertise persisted ops and warns once", func(t *testing.T) { + withCodeModeNamedOps(t, codeModeBackend{name: "memory"}, codeModeNamedOpsOptions{sessionStateless: boolPtr(true), observeLogs: true}, func(ctx context.Context, _ string, xEnv *testenv.Environment, _ *fakeCodeModeYoko, session *mcp.ClientSession) { + resources, err := session.ListResources(ctx, &mcp.ListResourcesParams{}) + require.NoError(t, err) + assert.Equal(t, []*mcp.Resource{}, resources.Resources) + + logs := xEnv.Observer().FilterMessage("code mode named operations are disabled because MCP session stateless mode is enabled").All() + assert.Equal(t, 1, len(logs)) + }) + }) + + t.Run("all gates on advertises persisted ops and read returns bundle", func(t *testing.T) { + withCodeModeNamedOps(t, codeModeBackend{name: "memory"}, codeModeNamedOpsOptions{}, func(ctx context.Context, _ string, _ *testenv.Environment, _ *fakeCodeModeYoko, session *mcp.ClientSession) { + searchText := callCodeModeToolText(t, ctx, session, "code_mode_search_tools", map[string]any{ + "prompts": []string{"employee by id"}, + }) + assert.Equal(t, employeeByIDTS, searchText) + + resources, err := session.ListResources(ctx, &mcp.ListResourcesParams{}) + require.NoError(t, err) + assert.Equal(t, []*mcp.Resource{{ + URI: codeModePersistedOpsURI, + Name: "persisted-ops.d.ts", + Title: "Persisted operations TypeScript definitions", + Description: "Cumulative TypeScript definitions for the current Code Mode MCP session's named operations.", + MIMEType: "text/plain", + }}, resources.Resources) + + resource := readPersistedOpsResource(t, ctx, session) + 
assert.Equal(t, &mcp.ReadResourceResult{Contents: []*mcp.ResourceContents{{ + URI: codeModePersistedOpsURI, + MIMEType: "text/plain", + Text: employeeByIDBundle, + }}}, resource) + }) + }) +} + +func TestCodeModeNamedOpsRedisBackendTransparent(t *testing.T) { + redisServer, err := miniredis.Run() + if err != nil { + t.Skipf("miniredis unavailable: %v", err) + } + t.Cleanup(redisServer.Close) + + backend := codeModeBackend{ + name: "redis", + providerID: "code_mode_redis", + redisURL: "redis://" + redisServer.Addr(), + } + withCodeModeNamedOps(t, backend, codeModeNamedOpsOptions{}, func(ctx context.Context, _ string, _ *testenv.Environment, _ *fakeCodeModeYoko, session *mcp.ClientSession) { + searchText := callCodeModeToolText(t, ctx, session, "code_mode_search_tools", map[string]any{ + "prompts": []string{"first employee", "employee by id"}, + }) + assert.Equal(t, twoOpsFragment, searchText) + + resource := readPersistedOpsResource(t, ctx, session) + assert.Equal(t, twoOpsBundle, resource.Contents[0].Text) + + executeText := callCodeModeToolText(t, ctx, session, "code_mode_run_js", map[string]any{ + "source": `async () => { return await tools.employeeByID({ id: 1 }); }`, + }) + assert.Equal(t, map[string]any{ + "result": map[string]any{ + "data": map[string]any{ + "employee": map[string]any{ + "id": float64(1), + "details": map[string]any{ + "forename": "Jens", + "surname": "Neuse", + }, + }, + }, + }, + }, decodeJSON(t, executeText)) + }) +} + +type codeModeNamedOpsOptions struct { + namedOpsEnabled *bool + sessionStateless *bool + observeLogs bool + poller *codeModeConfigPoller + elicitationHandler func(context.Context, *mcp.ElicitRequest) (*mcp.ElicitResult, error) +} + +func withCodeModeNamedOps(t *testing.T, backend codeModeBackend, opts codeModeNamedOpsOptions, f func(context.Context, string, *testenv.Environment, *fakeCodeModeYoko, *mcp.ClientSession)) { + t.Helper() + + yoko := newFakeCodeModeYoko() + yokoServer := startFakeCodeModeYoko(t, yoko) + cfg := 
baseCodeModeTestConfig(t, yokoServer.URL, backend, opts) + + testenv.Run(t, cfg, func(t *testing.T, xEnv *testenv.Environment) { + ctx := ctxWithTimeout(t) + endpoint := "http://" + cfg.MCP.CodeMode.Server.ListenAddr + "/mcp" + session := newCodeModeMCPClient(t, ctx, endpoint, opts.elicitationHandler) + f(ctx, endpoint, xEnv, yoko, session) + }) +} + +func baseCodeModeTestConfig(t *testing.T, yokoURL string, backend codeModeBackend, opts codeModeNamedOpsOptions) *testenv.Config { + t.Helper() + + ports := freeport.GetN(t, 2) + namedOpsEnabled := true + if opts.namedOpsEnabled != nil { + namedOpsEnabled = *opts.namedOpsEnabled + } + sessionStateless := false + if opts.sessionStateless != nil { + sessionStateless = *opts.sessionStateless + } + + mcpCfg := config.MCPConfiguration{ + Enabled: true, + Server: config.MCPServer{ + ListenAddr: fmt.Sprintf("127.0.0.1:%d", ports[0]), + }, + Session: config.MCPSessionConfig{Stateless: sessionStateless}, + CodeMode: config.MCPCodeModeConfiguration{ + Enabled: true, + RequireMutationApproval: true, + ExecuteTimeout: 30 * time.Second, + MaxResultBytes: 32 << 10, + Server: config.MCPCodeModeServerConfig{ + ListenAddr: fmt.Sprintf("127.0.0.1:%d", ports[1]), + }, + QueryGeneration: config.MCPCodeModeQueryGenConfig{ + Enabled: true, + Endpoint: yokoURL, + Timeout: 5 * time.Second, + }, + NamedOps: config.MCPCodeModeNamedOpsConfig{ + Enabled: namedOpsEnabled, + SessionTTL: 30 * time.Minute, + MaxSessions: 100, + MaxBundleBytes: 256 << 10, + Storage: config.MCPCodeModeNamedOpsStorageConfig{ + ProviderID: backend.providerID, + KeyPrefix: "router_tests_code_mode", + }, + }, + }, + } + + cfg := &testenv.Config{ + MCP: mcpCfg, + MCPOperationsPath: "protocol/testdata/mcp_operations_collision", + CodeModeRedisURL: backend.redisURL, + } + if opts.observeLogs { + cfg.LogObservation = testenv.LogObservationConfig{Enabled: true, LogLevel: zapcore.WarnLevel} + } + if opts.poller != nil { + cfg.RouterConfig = &testenv.RouterConfig{ + 
ConfigPollerFactory: func(routerConfig *nodev1.RouterConfig) configpoller.ConfigPoller { + opts.poller.initConfig = routerConfig + return opts.poller + }, + } + } + return cfg +} + +func newCodeModeMCPClient(t *testing.T, ctx context.Context, endpoint string, elicitation func(context.Context, *mcp.ElicitRequest) (*mcp.ElicitResult, error)) *mcp.ClientSession { + t.Helper() + + client := mcp.NewClient(&mcp.Implementation{Name: "router-tests", Version: "v0.0.0"}, &mcp.ClientOptions{ + ElicitationHandler: elicitation, + }) + transport := &mcp.StreamableClientTransport{ + Endpoint: endpoint, + DisableStandaloneSSE: true, + MaxRetries: -1, + } + session, err := client.Connect(ctx, transport, nil) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, session.Close()) + }) + return session +} + +func callCodeModeToolText(t *testing.T, ctx context.Context, session *mcp.ClientSession, name string, args map[string]any) string { + t.Helper() + result, err := session.CallTool(ctx, &mcp.CallToolParams{Name: name, Arguments: args}) + require.NoError(t, err) + require.False(t, result.IsError) + require.Len(t, result.Content, 1) + text, ok := result.Content[0].(*mcp.TextContent) + require.True(t, ok) + return text.Text +} + +func readPersistedOpsResource(t *testing.T, ctx context.Context, session *mcp.ClientSession) *mcp.ReadResourceResult { + t.Helper() + result, err := session.ReadResource(ctx, &mcp.ReadResourceParams{URI: codeModePersistedOpsURI}) + require.NoError(t, err) + return result +} + +func decodeJSON(t *testing.T, text string) map[string]any { + t.Helper() + var decoded map[string]any + require.NoError(t, json.Unmarshal([]byte(text), &decoded)) + return decoded +} + +func ctxWithTimeout(t *testing.T) context.Context { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + t.Cleanup(cancel) + return ctx +} + +func boolPtr(v bool) *bool { + return &v +} + +func mark3ResourcesContain(resources []mark3mcp.Resource, uri 
string) bool { + for _, resource := range resources { + if resource.URI == uri { + return true + } + } + return false +} + +type fakeCodeModeYoko struct { + mu sync.Mutex + indexCounter int + indexRequestLog []*yokov1.IndexRequest + searchRequestLog []*yokov1.SearchRequest + opsByPrompt map[string]*yokov1.GeneratedOperation +} + +func newFakeCodeModeYoko() *fakeCodeModeYoko { + return &fakeCodeModeYoko{ + opsByPrompt: map[string]*yokov1.GeneratedOperation{ + "first employee": { + Name: firstEmployeeOpName, + Body: firstEmployeeQuery, + Kind: yokov1.OperationKind_OPERATION_KIND_QUERY, + Description: "Fetch the first employee.", + }, + "employee by id": { + Name: employeeByIDOpName, + Body: employeeByIDQuery, + Kind: yokov1.OperationKind_OPERATION_KIND_QUERY, + Description: "Fetch employee by id.", + }, + "update employee tag": { + Name: updateTagOpName, + Body: updateTagMutation, + Kind: yokov1.OperationKind_OPERATION_KIND_MUTATION, + Description: "Update employee tag.", + }, + }, + } +} + +func startFakeCodeModeYoko(t *testing.T, svc *fakeCodeModeYoko) *httptest.Server { + t.Helper() + path, handler := yokoconnect.NewYokoServiceHandler(svc) + mux := http.NewServeMux() + mux.Handle(path, handler) + ports := freeport.GetN(t, 1) + listener, err := net.Listen("tcp4", fmt.Sprintf("127.0.0.1:%d", ports[0])) + require.NoError(t, err) + server := httptest.NewUnstartedServer(mux) + server.Listener = listener + server.Start() + t.Cleanup(server.Close) + return server +} + +func (f *fakeCodeModeYoko) Index(_ context.Context, req *connect.Request[yokov1.IndexRequest]) (*connect.Response[yokov1.IndexResponse], error) { + f.mu.Lock() + defer f.mu.Unlock() + f.indexCounter++ + f.indexRequestLog = append(f.indexRequestLog, &yokov1.IndexRequest{SchemaSdl: req.Msg.GetSchemaSdl()}) + return connect.NewResponse(&yokov1.IndexResponse{SchemaId: fmt.Sprintf("schema-%d", f.indexCounter)}), nil +} + +func (f *fakeCodeModeYoko) Search(_ context.Context, req 
*connect.Request[yokov1.SearchRequest]) (*connect.Response[yokov1.SearchResponse], error) { + f.mu.Lock() + defer f.mu.Unlock() + f.searchRequestLog = append(f.searchRequestLog, &yokov1.SearchRequest{ + Prompts: append([]string(nil), req.Msg.GetPrompts()...), + SchemaId: req.Msg.GetSchemaId(), + SessionId: req.Msg.GetSessionId(), + }) + ops := make([]*yokov1.GeneratedOperation, 0, len(req.Msg.GetPrompts())) + for _, prompt := range req.Msg.GetPrompts() { + if op := f.opsByPrompt[prompt]; op != nil { + ops = append(ops, op) + } + } + return connect.NewResponse(&yokov1.SearchResponse{Operations: ops}), nil +} + +func (f *fakeCodeModeYoko) indexRequests() []*yokov1.IndexRequest { + f.mu.Lock() + defer f.mu.Unlock() + out := make([]*yokov1.IndexRequest, 0, len(f.indexRequestLog)) + for _, req := range f.indexRequestLog { + out = append(out, &yokov1.IndexRequest{SchemaSdl: req.GetSchemaSdl()}) + } + return out +} + +func (f *fakeCodeModeYoko) searchRequests() []*yokov1.SearchRequest { + f.mu.Lock() + defer f.mu.Unlock() + out := make([]*yokov1.SearchRequest, 0, len(f.searchRequestLog)) + for _, req := range f.searchRequestLog { + out = append(out, &yokov1.SearchRequest{ + Prompts: append([]string(nil), req.GetPrompts()...), + SchemaId: req.GetSchemaId(), + SessionId: req.GetSessionId(), + }) + } + return out +} + +type codeModeConfigPoller struct { + initConfig *nodev1.RouterConfig + updateConfig func(newConfig *nodev1.RouterConfig, oldVersion string) error + ready chan struct{} + once sync.Once +} + +func (c *codeModeConfigPoller) Subscribe(_ context.Context, handler func(newConfig *nodev1.RouterConfig, oldVersion string) error) { + c.updateConfig = handler + c.once.Do(func() { close(c.ready) }) +} + +func (c *codeModeConfigPoller) GetRouterConfig(_ context.Context) (*routerconfig.Response, error) { + return &routerconfig.Response{Config: c.initConfig}, nil +} + +func (c *codeModeConfigPoller) Stop(_ context.Context) error { + return nil +} diff --git 
a/router-tests/go.mod b/router-tests/go.mod index 862856a1ae..e08d1f9168 100644 --- a/router-tests/go.mod +++ b/router-tests/go.mod @@ -5,6 +5,7 @@ go 1.25.0 require ( connectrpc.com/connect v1.19.1 github.com/MicahParks/jwkset v0.11.0 + github.com/alicebob/miniredis/v2 v2.34.0 github.com/buger/jsonparser v1.1.2 github.com/cloudflare/backoff v0.0.0-20240920015135-e46b80a3a7d0 github.com/golang-jwt/jwt/v5 v5.3.0 @@ -51,6 +52,7 @@ require ( github.com/KimMachineGun/automemlimit v0.6.1 // indirect github.com/MicahParks/keyfunc/v3 v3.6.2 // indirect github.com/agnivade/levenshtein v1.2.1 // indirect + github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 // indirect github.com/andybalholm/brotli v1.1.0 // indirect github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/benbjohnson/clock v1.3.0 // indirect @@ -74,7 +76,9 @@ require ( github.com/docker/docker-credential-helpers v0.9.3 // indirect github.com/docker/go-units v0.5.0 // indirect github.com/dustin/go-humanize v1.0.1 // indirect + github.com/evanw/esbuild v0.27.3 // indirect github.com/expr-lang/expr v1.17.7 // indirect + github.com/fastschema/qjs v0.0.6 // indirect github.com/fatih/color v1.18.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/go-chi/chi/v5 v5.2.2 // indirect @@ -146,6 +150,8 @@ require ( github.com/sosodev/duration v1.3.1 // indirect github.com/spf13/cast v1.7.1 // indirect github.com/stretchr/objx v0.5.2 // indirect + github.com/tdewolff/parse/v2 v2.8.12 // indirect + github.com/tetratelabs/wazero v1.9.0 // indirect github.com/tidwall/gjson v1.18.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect @@ -161,6 +167,7 @@ require ( github.com/wundergraph/go-arena v1.1.0 // indirect github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + github.com/yuin/gopher-lua v1.1.1 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect 
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.58.0 // indirect go.opentelemetry.io/contrib/propagators/b3 v1.23.0 // indirect diff --git a/router-tests/go.sum b/router-tests/go.sum index 28e214572e..85128dd4c4 100644 --- a/router-tests/go.sum +++ b/router-tests/go.sum @@ -87,8 +87,12 @@ github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4 github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/evanw/esbuild v0.27.3 h1:dH/to9tBKybig6hl25hg4SKIWP7U8COdJKbGEwnUkmU= +github.com/evanw/esbuild v0.27.3/go.mod h1:D2vIQZqV/vIf/VRHtViaUtViZmG7o+kKmlBfVQuRi48= github.com/expr-lang/expr v1.17.7 h1:Q0xY/e/2aCIp8g9s/LGvMDCC5PxYlvHgDZRQ4y16JX8= github.com/expr-lang/expr v1.17.7/go.mod h1:8/vRC7+7HBzESEqt5kKpYXxrxkr31SaO8r40VO/1IT4= +github.com/fastschema/qjs v0.0.6 h1:C45KMmQMd21UwsUAmQHxUxiWOfzwTg1GJW0DA0AbFEE= +github.com/fastschema/qjs v0.0.6/go.mod h1:bbg36wxXnx8g0FdKIe5+nCubrQvHa7XEVWqUptjHt/A= github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= @@ -333,6 +337,12 @@ github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXl github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/tdewolff/parse/v2 v2.8.12 h1:5BBjfaCv482v3nltlS0u6wH1xJaxjR6ofDrWttNvROg= +github.com/tdewolff/parse/v2 v2.8.12/go.mod h1:Hwlni2tiVNKyzR1o6nUs4FOF07URA+JLBLd6dlIXYqo= +github.com/tdewolff/test v1.0.11 
h1:FdLbwQVHxqG16SlkGveC0JVyrJN62COWTRyUFzfbtBE= +github.com/tdewolff/test v1.0.11/go.mod h1:XPuWBzvdUzhCuxWO1ojpXsyzsA5bFoS3tO/Q3kFuTG8= +github.com/tetratelabs/wazero v1.9.0 h1:IcZ56OuxrtaEz8UYNRHBrUa9bYeX9oVY93KspZZBf/I= +github.com/tetratelabs/wazero v1.9.0/go.mod h1:TSbcXCfFP0L2FGkRPxHphadXPjo1T6W+CseNNY7EkjM= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= diff --git a/router-tests/testenv/testenv.go b/router-tests/testenv/testenv.go index 8f3c0a1a21..edca6036c8 100644 --- a/router-tests/testenv/testenv.go +++ b/router-tests/testenv/testenv.go @@ -367,6 +367,10 @@ type Config struct { MCP config.MCPConfiguration MCPOperationsPath string MCPAuthToken string // Optional Bearer token for MCP authentication + // CodeModeRedisURL, when paired with MCP.CodeMode.NamedOps.Storage.ProviderID, + // registers a Redis storage provider with that ID so the named-ops backend can + // resolve it from the central provider registry. + CodeModeRedisURL string EnableRedis bool EnableRedisCluster bool Plugins PluginConfig @@ -1520,14 +1524,23 @@ func configureRouter(listenerAddr string, testConfig *Config, routerConfig *node if testConfig.MCPOperationsPath != "" { mcpOperationsPath = testConfig.MCPOperationsPath } - routerOpts = append(routerOpts, core.WithStorageProviders(config.StorageProviders{ + storageProviders := config.StorageProviders{ FileSystem: []config.FileSystemStorageProvider{ { ID: "test", Path: mcpOperationsPath, }, }, - })) + } + // Append a Redis provider for code mode named ops when the test set a + // provider_id and supplied a URL via CodeModeRedisURL. 
+ if id := testConfig.MCP.CodeMode.NamedOps.Storage.ProviderID; id != "" && testConfig.CodeModeRedisURL != "" { + storageProviders.Redis = append(storageProviders.Redis, config.RedisStorageProvider{ + ID: id, + URLs: []string{testConfig.CodeModeRedisURL}, + }) + } + routerOpts = append(routerOpts, core.WithStorageProviders(storageProviders)) testConfig.MCP.Storage.ProviderID = "test" diff --git a/router/core/graph_server.go b/router/core/graph_server.go index 907d5abfbb..830b29e5a5 100644 --- a/router/core/graph_server.go +++ b/router/core/graph_server.go @@ -59,6 +59,7 @@ import ( rtrace "github.com/wundergraph/cosmo/router/pkg/trace" "github.com/wundergraph/graphql-go-tools/v2/pkg/astparser" + "github.com/wundergraph/graphql-go-tools/v2/pkg/astprinter" ) const ( @@ -1399,6 +1400,14 @@ func (s *graphServer) buildGraphMux( return nil, fmt.Errorf("failed to reload MCP server: %w", mErr) } } + if opts.IsBaseGraph() && s.codeModeServer != nil { + sdl, printErr := astprinter.PrintString(executor.ClientSchema) + if printErr != nil { + s.logger.Error("failed to reload Code Mode MCP server", zap.Error(fmt.Errorf("failed to print Code Mode schema SDL: %w", printErr))) + } else if mErr := s.codeModeServer.Reload(executor.ClientSchema, sdl); mErr != nil { + s.logger.Error("failed to reload Code Mode MCP server", zap.Error(mErr)) + } + } if s.cacheWarmup != nil && s.cacheWarmup.Enabled { processor := NewCacheWarmupPlanningProcessor(&CacheWarmupPlanningProcessorOptions{ diff --git a/router/core/router.go b/router/core/router.go index cb173417d3..5bcb17f985 100644 --- a/router/core/router.go +++ b/router/core/router.go @@ -27,6 +27,7 @@ import ( "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/graphqlmetrics/v1/graphqlmetricsv1connect" nodev1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/node/v1" "github.com/wundergraph/cosmo/router/internal/circuit" + codemodeserver "github.com/wundergraph/cosmo/router/internal/codemode/server" "github.com/wundergraph/cosmo/router/internal/debug"
"github.com/wundergraph/cosmo/router/internal/docker" "github.com/wundergraph/cosmo/router/internal/exporter" @@ -947,6 +948,9 @@ func (r *Router) bootstrap(ctx context.Context) error { if err := r.startMCPServer(ctx); err != nil { return err } + if err := r.startCodeModeServer(ctx); err != nil { + return err + } if r.connectRPC.Enabled { r.logger.Debug("ConnectRPC configuration", @@ -1150,6 +1154,76 @@ func (r *Router) startMCPServer(ctx context.Context) error { return nil } +// startCodeModeServer initializes and starts the separate Code Mode MCP server if enabled. +func (r *Router) startCodeModeServer(ctx context.Context) error { + var redisProvider *config.RedisStorageProvider + if r.mcp.CodeMode.Enabled && r.mcp.CodeMode.NamedOps.Enabled { + if providerID := r.mcp.CodeMode.NamedOps.Storage.ProviderID; providerID != "" { + provider, ok := r.providerRegistry.Redis(providerID) + if !ok { + return fmt.Errorf("redis storage provider with id '%s' for mcp code_mode named_ops not found", providerID) + } + redisProvider = &provider + } + } + + cm, err := codemodeserver.BuildFromConfig(codemodeserver.BuildOptions{ + Config: r.mcp.CodeMode, + SessionStateless: r.mcp.Session.Stateless, + RouterGraphQLURL: r.graphqlEndpointURL, + Logger: r.logger, + TracerProvider: r.tracerProvider, + MeterProvider: r.otlpMeterProvider, + RedisProvider: redisProvider, + RedisFactory: func(opts *rd.RedisCloserOptions) (rd.RDCloser, error) { + if opts.Logger == nil { + opts.Logger = r.logger + } + return rd.NewRedisCloser(opts) + }, + }) + if err != nil { + return fmt.Errorf("failed to create code mode MCP server: %w", err) + } + r.codeModeServer = cm + + if !r.mcp.CodeMode.Enabled { + return nil + } + + errs := make(chan error, 1) + go func() { + errs <- cm.Start(ctx) + }() + + deadline := time.NewTimer(5 * time.Second) + defer deadline.Stop() + tick := time.NewTicker(10 * time.Millisecond) + defer tick.Stop() + for { + select { + case err := <-errs: + if err != nil { + return 
fmt.Errorf("failed to start code mode MCP server: %w", err) + } + return nil + case <-ctx.Done(): + return ctx.Err() + case <-deadline.C: + return fmt.Errorf("failed to start code mode MCP server: listener was not bound") + case <-tick.C: + if cm.Addr() != "" { + go func() { + if err := <-errs; err != nil { + r.logger.Error("Code Mode MCP server stopped unexpectedly", zap.Error(err)) + } + }() + return nil + } + } + } +} + // buildClients initializes the storage clients for persisted operations and router config. func (r *Router) buildClients(ctx context.Context) error { registry := r.providerRegistry @@ -1722,6 +1796,14 @@ func (r *Router) Shutdown(ctx context.Context) error { }) } + if r.codeModeServer != nil { + wg.Go(func() { + if subErr := r.codeModeServer.Stop(ctx); subErr != nil { + err.Append(fmt.Errorf("failed to shutdown code mode MCP server: %w", subErr)) + } + }) + } + if r.connectRPCServer != nil { wg.Go(func() { if subErr := r.connectRPCServer.Stop(ctx); subErr != nil { diff --git a/router/core/router_config.go b/router/core/router_config.go index 9f4b0bf84c..688cd61b22 100644 --- a/router/core/router_config.go +++ b/router/core/router_config.go @@ -6,6 +6,7 @@ import ( "time" nodev1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/node/v1" + codemodeserver "github.com/wundergraph/cosmo/router/internal/codemode/server" "github.com/wundergraph/cosmo/router/internal/graphqlmetrics" "github.com/wundergraph/cosmo/router/internal/persistedoperation" "github.com/wundergraph/cosmo/router/internal/persistedoperation/pqlmanifest" @@ -113,6 +114,7 @@ type Config struct { retryOptions retrytransport.RetryOptions redisClient rd.RDCloser mcpServer *mcpserver.GraphQLSchemaServer + codeModeServer *codemodeserver.Server connectRPCServer *connectrpc.Server processStartTime time.Time developmentMode bool diff --git a/router/gen/proto/wg/cosmo/code_mode/yoko/v1/yoko.pb.go b/router/gen/proto/wg/cosmo/code_mode/yoko/v1/yoko.pb.go new file mode 100644 index 
0000000000..c1fb97bbbd --- /dev/null +++ b/router/gen/proto/wg/cosmo/code_mode/yoko/v1/yoko.pb.go @@ -0,0 +1,451 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.10 +// protoc (unknown) +// source: wg/cosmo/code_mode/yoko/v1/yoko.proto + +package yokov1 + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type OperationKind int32 + +const ( + OperationKind_OPERATION_KIND_UNSPECIFIED OperationKind = 0 + OperationKind_OPERATION_KIND_QUERY OperationKind = 1 + OperationKind_OPERATION_KIND_MUTATION OperationKind = 2 +) + +// Enum value maps for OperationKind. +var ( + OperationKind_name = map[int32]string{ + 0: "OPERATION_KIND_UNSPECIFIED", + 1: "OPERATION_KIND_QUERY", + 2: "OPERATION_KIND_MUTATION", + } + OperationKind_value = map[string]int32{ + "OPERATION_KIND_UNSPECIFIED": 0, + "OPERATION_KIND_QUERY": 1, + "OPERATION_KIND_MUTATION": 2, + } +) + +func (x OperationKind) Enum() *OperationKind { + p := new(OperationKind) + *p = x + return p +} + +func (x OperationKind) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (OperationKind) Descriptor() protoreflect.EnumDescriptor { + return file_wg_cosmo_code_mode_yoko_v1_yoko_proto_enumTypes[0].Descriptor() +} + +func (OperationKind) Type() protoreflect.EnumType { + return &file_wg_cosmo_code_mode_yoko_v1_yoko_proto_enumTypes[0] +} + +func (x OperationKind) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use OperationKind.Descriptor instead. 
+func (OperationKind) EnumDescriptor() ([]byte, []int) { + return file_wg_cosmo_code_mode_yoko_v1_yoko_proto_rawDescGZIP(), []int{0} +} + +type IndexRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + // The supergraph SDL to index. Sent in full on every Index call; + // Yoko deduplicates internally and is free to short-circuit when + // the SDL is already known. + SchemaSdl string `protobuf:"bytes,1,opt,name=schema_sdl,json=schemaSdl,proto3" json:"schema_sdl,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *IndexRequest) Reset() { + *x = IndexRequest{} + mi := &file_wg_cosmo_code_mode_yoko_v1_yoko_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *IndexRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*IndexRequest) ProtoMessage() {} + +func (x *IndexRequest) ProtoReflect() protoreflect.Message { + mi := &file_wg_cosmo_code_mode_yoko_v1_yoko_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use IndexRequest.ProtoReflect.Descriptor instead. +func (*IndexRequest) Descriptor() ([]byte, []int) { + return file_wg_cosmo_code_mode_yoko_v1_yoko_proto_rawDescGZIP(), []int{0} +} + +func (x *IndexRequest) GetSchemaSdl() string { + if x != nil { + return x.SchemaSdl + } + return "" +} + +type IndexResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Opaque, Yoko-assigned identifier for this schema. Stable for as + // long as Yoko retains the index. Subsequent Search calls pass this + // back instead of the full SDL. Idempotent: the same SDL returns + // the same schema_id. 
+ SchemaId string `protobuf:"bytes,1,opt,name=schema_id,json=schemaId,proto3" json:"schema_id,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *IndexResponse) Reset() { + *x = IndexResponse{} + mi := &file_wg_cosmo_code_mode_yoko_v1_yoko_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *IndexResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*IndexResponse) ProtoMessage() {} + +func (x *IndexResponse) ProtoReflect() protoreflect.Message { + mi := &file_wg_cosmo_code_mode_yoko_v1_yoko_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use IndexResponse.ProtoReflect.Descriptor instead. +func (*IndexResponse) Descriptor() ([]byte, []int) { + return file_wg_cosmo_code_mode_yoko_v1_yoko_proto_rawDescGZIP(), []int{1} +} + +func (x *IndexResponse) GetSchemaId() string { + if x != nil { + return x.SchemaId + } + return "" +} + +type SearchRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Batch of natural-language prompts. Bounded at 20 by the host. + Prompts []string `protobuf:"bytes,1,rep,name=prompts,proto3" json:"prompts,omitempty"` + // Identifier returned by a prior Index call. If Yoko no longer + // recognizes the id (e.g. eviction, restart), it MUST return the + // Connect error code NOT_FOUND; the router re-indexes and retries + // the call exactly once. + SchemaId string `protobuf:"bytes,2,opt,name=schema_id,json=schemaId,proto3" json:"schema_id,omitempty"` + // Opaque MCP session ID for telemetry correlation only. + // Yoko MUST NOT use this for stateful behavior — sessions are owned + // by the router. 
+ SessionId string `protobuf:"bytes,3,opt,name=session_id,json=sessionId,proto3" json:"session_id,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *SearchRequest) Reset() { + *x = SearchRequest{} + mi := &file_wg_cosmo_code_mode_yoko_v1_yoko_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SearchRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SearchRequest) ProtoMessage() {} + +func (x *SearchRequest) ProtoReflect() protoreflect.Message { + mi := &file_wg_cosmo_code_mode_yoko_v1_yoko_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SearchRequest.ProtoReflect.Descriptor instead. +func (*SearchRequest) Descriptor() ([]byte, []int) { + return file_wg_cosmo_code_mode_yoko_v1_yoko_proto_rawDescGZIP(), []int{2} +} + +func (x *SearchRequest) GetPrompts() []string { + if x != nil { + return x.Prompts + } + return nil +} + +func (x *SearchRequest) GetSchemaId() string { + if x != nil { + return x.SchemaId + } + return "" +} + +func (x *SearchRequest) GetSessionId() string { + if x != nil { + return x.SessionId + } + return "" +} + +type SearchResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Operations across all prompts, already deduplicated and ranked. + // Order is significant: earlier entries rank higher and are preferred + // when bundle truncation drops from the tail. 
+ Operations []*GeneratedOperation `protobuf:"bytes,1,rep,name=operations,proto3" json:"operations,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *SearchResponse) Reset() { + *x = SearchResponse{} + mi := &file_wg_cosmo_code_mode_yoko_v1_yoko_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SearchResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SearchResponse) ProtoMessage() {} + +func (x *SearchResponse) ProtoReflect() protoreflect.Message { + mi := &file_wg_cosmo_code_mode_yoko_v1_yoko_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SearchResponse.ProtoReflect.Descriptor instead. +func (*SearchResponse) Descriptor() ([]byte, []int) { + return file_wg_cosmo_code_mode_yoko_v1_yoko_proto_rawDescGZIP(), []int{3} +} + +func (x *SearchResponse) GetOperations() []*GeneratedOperation { + if x != nil { + return x.Operations + } + return nil +} + +type GeneratedOperation struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Suggested operation name (camelCase preferred). The host applies + // its own identifier normalization and in-session collision-suffix + // logic on top of this — see §6. + Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + // GraphQL operation body (query or mutation source text). + Body string `protobuf:"bytes,2,opt,name=body,proto3" json:"body,omitempty"` + // Operation kind. Subscriptions are out of scope; if Yoko returns + // one, the host drops it with a single warn log. 
+ Kind OperationKind `protobuf:"varint,3,opt,name=kind,proto3,enum=wundergraph.cosmo.code_mode.yoko.v1.OperationKind" json:"kind,omitempty"` + // Human-readable description, surfaced as JSDoc on the typed + // `tools.` signature in the rendered bundle. + Description string `protobuf:"bytes,4,opt,name=description,proto3" json:"description,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GeneratedOperation) Reset() { + *x = GeneratedOperation{} + mi := &file_wg_cosmo_code_mode_yoko_v1_yoko_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GeneratedOperation) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GeneratedOperation) ProtoMessage() {} + +func (x *GeneratedOperation) ProtoReflect() protoreflect.Message { + mi := &file_wg_cosmo_code_mode_yoko_v1_yoko_proto_msgTypes[4] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GeneratedOperation.ProtoReflect.Descriptor instead. 
+func (*GeneratedOperation) Descriptor() ([]byte, []int) { + return file_wg_cosmo_code_mode_yoko_v1_yoko_proto_rawDescGZIP(), []int{4} +} + +func (x *GeneratedOperation) GetName() string { + if x != nil { + return x.Name + } + return "" +} + +func (x *GeneratedOperation) GetBody() string { + if x != nil { + return x.Body + } + return "" +} + +func (x *GeneratedOperation) GetKind() OperationKind { + if x != nil { + return x.Kind + } + return OperationKind_OPERATION_KIND_UNSPECIFIED +} + +func (x *GeneratedOperation) GetDescription() string { + if x != nil { + return x.Description + } + return "" +} + +var File_wg_cosmo_code_mode_yoko_v1_yoko_proto protoreflect.FileDescriptor + +const file_wg_cosmo_code_mode_yoko_v1_yoko_proto_rawDesc = "" + + "\n" + + "%wg/cosmo/code_mode/yoko/v1/yoko.proto\x12#wundergraph.cosmo.code_mode.yoko.v1\"-\n" + + "\fIndexRequest\x12\x1d\n" + + "\n" + + "schema_sdl\x18\x01 \x01(\tR\tschemaSdl\",\n" + + "\rIndexResponse\x12\x1b\n" + + "\tschema_id\x18\x01 \x01(\tR\bschemaId\"e\n" + + "\rSearchRequest\x12\x18\n" + + "\aprompts\x18\x01 \x03(\tR\aprompts\x12\x1b\n" + + "\tschema_id\x18\x02 \x01(\tR\bschemaId\x12\x1d\n" + + "\n" + + "session_id\x18\x03 \x01(\tR\tsessionId\"i\n" + + "\x0eSearchResponse\x12W\n" + + "\n" + + "operations\x18\x01 \x03(\v27.wundergraph.cosmo.code_mode.yoko.v1.GeneratedOperationR\n" + + "operations\"\xa6\x01\n" + + "\x12GeneratedOperation\x12\x12\n" + + "\x04name\x18\x01 \x01(\tR\x04name\x12\x12\n" + + "\x04body\x18\x02 \x01(\tR\x04body\x12F\n" + + "\x04kind\x18\x03 \x01(\x0e22.wundergraph.cosmo.code_mode.yoko.v1.OperationKindR\x04kind\x12 \n" + + "\vdescription\x18\x04 \x01(\tR\vdescription*f\n" + + "\rOperationKind\x12\x1e\n" + + "\x1aOPERATION_KIND_UNSPECIFIED\x10\x00\x12\x18\n" + + "\x14OPERATION_KIND_QUERY\x10\x01\x12\x1b\n" + + "\x17OPERATION_KIND_MUTATION\x10\x022\xf0\x01\n" + + "\vYokoService\x12n\n" + + 
"\x05Index\x121.wundergraph.cosmo.code_mode.yoko.v1.IndexRequest\x1a2.wundergraph.cosmo.code_mode.yoko.v1.IndexResponse\x12q\n" + + "\x06Search\x122.wundergraph.cosmo.code_mode.yoko.v1.SearchRequest\x1a3.wundergraph.cosmo.code_mode.yoko.v1.SearchResponseB\xb2\x02\n" + + "'com.wundergraph.cosmo.code_mode.yoko.v1B\tYokoProtoP\x01ZOgithub.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/code_mode/yoko/v1;yokov1\xa2\x02\x04WCCY\xaa\x02\"Wundergraph.Cosmo.CodeMode.Yoko.V1\xca\x02\"Wundergraph\\Cosmo\\CodeMode\\Yoko\\V1\xe2\x02.Wundergraph\\Cosmo\\CodeMode\\Yoko\\V1\\GPBMetadata\xea\x02&Wundergraph::Cosmo::CodeMode::Yoko::V1b\x06proto3" + +var ( + file_wg_cosmo_code_mode_yoko_v1_yoko_proto_rawDescOnce sync.Once + file_wg_cosmo_code_mode_yoko_v1_yoko_proto_rawDescData []byte +) + +func file_wg_cosmo_code_mode_yoko_v1_yoko_proto_rawDescGZIP() []byte { + file_wg_cosmo_code_mode_yoko_v1_yoko_proto_rawDescOnce.Do(func() { + file_wg_cosmo_code_mode_yoko_v1_yoko_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_wg_cosmo_code_mode_yoko_v1_yoko_proto_rawDesc), len(file_wg_cosmo_code_mode_yoko_v1_yoko_proto_rawDesc))) + }) + return file_wg_cosmo_code_mode_yoko_v1_yoko_proto_rawDescData +} + +var file_wg_cosmo_code_mode_yoko_v1_yoko_proto_enumTypes = make([]protoimpl.EnumInfo, 1) +var file_wg_cosmo_code_mode_yoko_v1_yoko_proto_msgTypes = make([]protoimpl.MessageInfo, 5) +var file_wg_cosmo_code_mode_yoko_v1_yoko_proto_goTypes = []any{ + (OperationKind)(0), // 0: wundergraph.cosmo.code_mode.yoko.v1.OperationKind + (*IndexRequest)(nil), // 1: wundergraph.cosmo.code_mode.yoko.v1.IndexRequest + (*IndexResponse)(nil), // 2: wundergraph.cosmo.code_mode.yoko.v1.IndexResponse + (*SearchRequest)(nil), // 3: wundergraph.cosmo.code_mode.yoko.v1.SearchRequest + (*SearchResponse)(nil), // 4: wundergraph.cosmo.code_mode.yoko.v1.SearchResponse + (*GeneratedOperation)(nil), // 5: wundergraph.cosmo.code_mode.yoko.v1.GeneratedOperation +} +var 
file_wg_cosmo_code_mode_yoko_v1_yoko_proto_depIdxs = []int32{ + 5, // 0: wundergraph.cosmo.code_mode.yoko.v1.SearchResponse.operations:type_name -> wundergraph.cosmo.code_mode.yoko.v1.GeneratedOperation + 0, // 1: wundergraph.cosmo.code_mode.yoko.v1.GeneratedOperation.kind:type_name -> wundergraph.cosmo.code_mode.yoko.v1.OperationKind + 1, // 2: wundergraph.cosmo.code_mode.yoko.v1.YokoService.Index:input_type -> wundergraph.cosmo.code_mode.yoko.v1.IndexRequest + 3, // 3: wundergraph.cosmo.code_mode.yoko.v1.YokoService.Search:input_type -> wundergraph.cosmo.code_mode.yoko.v1.SearchRequest + 2, // 4: wundergraph.cosmo.code_mode.yoko.v1.YokoService.Index:output_type -> wundergraph.cosmo.code_mode.yoko.v1.IndexResponse + 4, // 5: wundergraph.cosmo.code_mode.yoko.v1.YokoService.Search:output_type -> wundergraph.cosmo.code_mode.yoko.v1.SearchResponse + 4, // [4:6] is the sub-list for method output_type + 2, // [2:4] is the sub-list for method input_type + 2, // [2:2] is the sub-list for extension type_name + 2, // [2:2] is the sub-list for extension extendee + 0, // [0:2] is the sub-list for field type_name +} + +func init() { file_wg_cosmo_code_mode_yoko_v1_yoko_proto_init() } +func file_wg_cosmo_code_mode_yoko_v1_yoko_proto_init() { + if File_wg_cosmo_code_mode_yoko_v1_yoko_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_wg_cosmo_code_mode_yoko_v1_yoko_proto_rawDesc), len(file_wg_cosmo_code_mode_yoko_v1_yoko_proto_rawDesc)), + NumEnums: 1, + NumMessages: 5, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_wg_cosmo_code_mode_yoko_v1_yoko_proto_goTypes, + DependencyIndexes: file_wg_cosmo_code_mode_yoko_v1_yoko_proto_depIdxs, + EnumInfos: file_wg_cosmo_code_mode_yoko_v1_yoko_proto_enumTypes, + MessageInfos: file_wg_cosmo_code_mode_yoko_v1_yoko_proto_msgTypes, + }.Build() + 
File_wg_cosmo_code_mode_yoko_v1_yoko_proto = out.File + file_wg_cosmo_code_mode_yoko_v1_yoko_proto_goTypes = nil + file_wg_cosmo_code_mode_yoko_v1_yoko_proto_depIdxs = nil +} diff --git a/router/gen/proto/wg/cosmo/code_mode/yoko/v1/yokov1connect/yoko.connect.go b/router/gen/proto/wg/cosmo/code_mode/yoko/v1/yokov1connect/yoko.connect.go new file mode 100644 index 0000000000..1e157644aa --- /dev/null +++ b/router/gen/proto/wg/cosmo/code_mode/yoko/v1/yokov1connect/yoko.connect.go @@ -0,0 +1,142 @@ +// Code generated by protoc-gen-connect-go. DO NOT EDIT. +// +// Source: wg/cosmo/code_mode/yoko/v1/yoko.proto + +package yokov1connect + +import ( + connect "connectrpc.com/connect" + context "context" + errors "errors" + v1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/code_mode/yoko/v1" + http "net/http" + strings "strings" +) + +// This is a compile-time assertion to ensure that this generated file and the connect package are +// compatible. If you get a compiler error that this constant is not defined, this code was +// generated with a version of connect newer than the one compiled into your binary. You can fix the +// problem by either regenerating this code with an older version of connect or updating the connect +// version compiled into your binary. +const _ = connect.IsAtLeastVersion1_13_0 + +const ( + // YokoServiceName is the fully-qualified name of the YokoService service. + YokoServiceName = "wundergraph.cosmo.code_mode.yoko.v1.YokoService" +) + +// These constants are the fully-qualified names of the RPCs defined in this package. They're +// exposed at runtime as Spec.Procedure and as the final two segments of the HTTP route. +// +// Note that these are different from the fully-qualified method names used by +// google.golang.org/protobuf/reflect/protoreflect. To convert from these constants to +// reflection-formatted method names, remove the leading slash and convert the remaining slash to a +// period. 
+const ( + // YokoServiceIndexProcedure is the fully-qualified name of the YokoService's Index RPC. + YokoServiceIndexProcedure = "/wundergraph.cosmo.code_mode.yoko.v1.YokoService/Index" + // YokoServiceSearchProcedure is the fully-qualified name of the YokoService's Search RPC. + YokoServiceSearchProcedure = "/wundergraph.cosmo.code_mode.yoko.v1.YokoService/Search" +) + +// These variables are the protoreflect.Descriptor objects for the RPCs defined in this package. +var ( + yokoServiceServiceDescriptor = v1.File_wg_cosmo_code_mode_yoko_v1_yoko_proto.Services().ByName("YokoService") + yokoServiceIndexMethodDescriptor = yokoServiceServiceDescriptor.Methods().ByName("Index") + yokoServiceSearchMethodDescriptor = yokoServiceServiceDescriptor.Methods().ByName("Search") +) + +// YokoServiceClient is a client for the wundergraph.cosmo.code_mode.yoko.v1.YokoService service. +type YokoServiceClient interface { + Index(context.Context, *connect.Request[v1.IndexRequest]) (*connect.Response[v1.IndexResponse], error) + Search(context.Context, *connect.Request[v1.SearchRequest]) (*connect.Response[v1.SearchResponse], error) +} + +// NewYokoServiceClient constructs a client for the wundergraph.cosmo.code_mode.yoko.v1.YokoService +// service. By default, it uses the Connect protocol with the binary Protobuf Codec, asks for +// gzipped responses, and sends uncompressed requests. To use the gRPC or gRPC-Web protocols, supply +// the connect.WithGRPC() or connect.WithGRPCWeb() options. +// +// The URL supplied here should be the base URL for the Connect or gRPC server (for example, +// http://api.acme.com or https://acme.com/grpc). 
+func NewYokoServiceClient(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) YokoServiceClient { + baseURL = strings.TrimRight(baseURL, "/") + return &yokoServiceClient{ + index: connect.NewClient[v1.IndexRequest, v1.IndexResponse]( + httpClient, + baseURL+YokoServiceIndexProcedure, + connect.WithSchema(yokoServiceIndexMethodDescriptor), + connect.WithClientOptions(opts...), + ), + search: connect.NewClient[v1.SearchRequest, v1.SearchResponse]( + httpClient, + baseURL+YokoServiceSearchProcedure, + connect.WithSchema(yokoServiceSearchMethodDescriptor), + connect.WithClientOptions(opts...), + ), + } +} + +// yokoServiceClient implements YokoServiceClient. +type yokoServiceClient struct { + index *connect.Client[v1.IndexRequest, v1.IndexResponse] + search *connect.Client[v1.SearchRequest, v1.SearchResponse] +} + +// Index calls wundergraph.cosmo.code_mode.yoko.v1.YokoService.Index. +func (c *yokoServiceClient) Index(ctx context.Context, req *connect.Request[v1.IndexRequest]) (*connect.Response[v1.IndexResponse], error) { + return c.index.CallUnary(ctx, req) +} + +// Search calls wundergraph.cosmo.code_mode.yoko.v1.YokoService.Search. +func (c *yokoServiceClient) Search(ctx context.Context, req *connect.Request[v1.SearchRequest]) (*connect.Response[v1.SearchResponse], error) { + return c.search.CallUnary(ctx, req) +} + +// YokoServiceHandler is an implementation of the wundergraph.cosmo.code_mode.yoko.v1.YokoService +// service. +type YokoServiceHandler interface { + Index(context.Context, *connect.Request[v1.IndexRequest]) (*connect.Response[v1.IndexResponse], error) + Search(context.Context, *connect.Request[v1.SearchRequest]) (*connect.Response[v1.SearchResponse], error) +} + +// NewYokoServiceHandler builds an HTTP handler from the service implementation. It returns the path +// on which to mount the handler and the handler itself. 
+// +// By default, handlers support the Connect, gRPC, and gRPC-Web protocols with the binary Protobuf +// and JSON codecs. They also support gzip compression. +func NewYokoServiceHandler(svc YokoServiceHandler, opts ...connect.HandlerOption) (string, http.Handler) { + yokoServiceIndexHandler := connect.NewUnaryHandler( + YokoServiceIndexProcedure, + svc.Index, + connect.WithSchema(yokoServiceIndexMethodDescriptor), + connect.WithHandlerOptions(opts...), + ) + yokoServiceSearchHandler := connect.NewUnaryHandler( + YokoServiceSearchProcedure, + svc.Search, + connect.WithSchema(yokoServiceSearchMethodDescriptor), + connect.WithHandlerOptions(opts...), + ) + return "/wundergraph.cosmo.code_mode.yoko.v1.YokoService/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case YokoServiceIndexProcedure: + yokoServiceIndexHandler.ServeHTTP(w, r) + case YokoServiceSearchProcedure: + yokoServiceSearchHandler.ServeHTTP(w, r) + default: + http.NotFound(w, r) + } + }) +} + +// UnimplementedYokoServiceHandler returns CodeUnimplemented from all methods. 
+type UnimplementedYokoServiceHandler struct{} + +func (UnimplementedYokoServiceHandler) Index(context.Context, *connect.Request[v1.IndexRequest]) (*connect.Response[v1.IndexResponse], error) { + return nil, connect.NewError(connect.CodeUnimplemented, errors.New("wundergraph.cosmo.code_mode.yoko.v1.YokoService.Index is not implemented")) +} + +func (UnimplementedYokoServiceHandler) Search(context.Context, *connect.Request[v1.SearchRequest]) (*connect.Response[v1.SearchResponse], error) { + return nil, connect.NewError(connect.CodeUnimplemented, errors.New("wundergraph.cosmo.code_mode.yoko.v1.YokoService.Search is not implemented")) +} diff --git a/router/go.mod b/router/go.mod index c2604da4a6..499f3b8aad 100644 --- a/router/go.mod +++ b/router/go.mod @@ -80,6 +80,7 @@ require ( github.com/posthog/posthog-go v1.5.5 github.com/pquerna/cachecontrol v0.2.0 github.com/santhosh-tekuri/jsonschema/v6 v6.0.1 + github.com/tdewolff/parse/v2 v2.8.12 github.com/tonglil/opentelemetry-go-datadog-propagator v0.1.3 github.com/wundergraph/astjson v1.1.0 github.com/wundergraph/go-arena v1.1.0 @@ -91,6 +92,8 @@ require ( golang.org/x/time v0.9.0 ) +require github.com/tetratelabs/wazero v1.9.0 // indirect + require ( github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 // indirect github.com/benbjohnson/clock v1.3.0 // indirect @@ -107,6 +110,8 @@ require ( github.com/docker/distribution v2.8.3+incompatible // indirect github.com/docker/docker-credential-helpers v0.9.3 // indirect github.com/docker/go-units v0.5.0 // indirect + github.com/evanw/esbuild v0.27.3 + github.com/fastschema/qjs v0.0.6 github.com/fatih/color v1.18.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/frankban/quicktest v1.14.6 // indirect diff --git a/router/go.sum b/router/go.sum index 561cbf94cd..7e82879a7d 100644 --- a/router/go.sum +++ b/router/go.sum @@ -73,8 +73,12 @@ github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4 
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/evanw/esbuild v0.27.3 h1:dH/to9tBKybig6hl25hg4SKIWP7U8COdJKbGEwnUkmU= +github.com/evanw/esbuild v0.27.3/go.mod h1:D2vIQZqV/vIf/VRHtViaUtViZmG7o+kKmlBfVQuRi48= github.com/expr-lang/expr v1.17.7 h1:Q0xY/e/2aCIp8g9s/LGvMDCC5PxYlvHgDZRQ4y16JX8= github.com/expr-lang/expr v1.17.7/go.mod h1:8/vRC7+7HBzESEqt5kKpYXxrxkr31SaO8r40VO/1IT4= +github.com/fastschema/qjs v0.0.6 h1:C45KMmQMd21UwsUAmQHxUxiWOfzwTg1GJW0DA0AbFEE= +github.com/fastschema/qjs v0.0.6/go.mod h1:bbg36wxXnx8g0FdKIe5+nCubrQvHa7XEVWqUptjHt/A= github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= @@ -299,6 +303,12 @@ github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXl github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/tdewolff/parse/v2 v2.8.12 h1:5BBjfaCv482v3nltlS0u6wH1xJaxjR6ofDrWttNvROg= +github.com/tdewolff/parse/v2 v2.8.12/go.mod h1:Hwlni2tiVNKyzR1o6nUs4FOF07URA+JLBLd6dlIXYqo= +github.com/tdewolff/test v1.0.11 h1:FdLbwQVHxqG16SlkGveC0JVyrJN62COWTRyUFzfbtBE= +github.com/tdewolff/test v1.0.11/go.mod h1:XPuWBzvdUzhCuxWO1ojpXsyzsA5bFoS3tO/Q3kFuTG8= +github.com/tetratelabs/wazero v1.9.0 h1:IcZ56OuxrtaEz8UYNRHBrUa9bYeX9oVY93KspZZBf/I= +github.com/tetratelabs/wazero v1.9.0/go.mod h1:TSbcXCfFP0L2FGkRPxHphadXPjo1T6W+CseNNY7EkjM= github.com/tidwall/gjson v1.14.2/go.mod 
h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= diff --git a/router/internal/codemode/calltrace/calltrace.go b/router/internal/codemode/calltrace/calltrace.go new file mode 100644 index 0000000000..a6d3286e3b --- /dev/null +++ b/router/internal/codemode/calltrace/calltrace.go @@ -0,0 +1,92 @@ +package calltrace + +import ( + "encoding/json" + "os" + "sync" + "time" +) + +type Recorder interface { + RecordRequest(toolName string, body []byte) + RecordResponse(toolName string, body []byte) +} + +type Record struct { + ToolName string `json:"tool_name"` + Timestamp time.Time `json:"timestamp"` + Body json.RawMessage `json:"body"` +} + +type NopRecorder struct{} + +func (NopRecorder) RecordRequest(string, []byte) {} +func (NopRecorder) RecordResponse(string, []byte) {} + +type FileRecorder struct { + path string + now func() time.Time + mu sync.Mutex +} + +type Option func(*FileRecorder) + +func WithNow(now func() time.Time) Option { + return func(r *FileRecorder) { + if now != nil { + r.now = now + } + } +} + +func NewFileRecorder(path string, opts ...Option) *FileRecorder { + recorder := &FileRecorder{ + path: path, + now: time.Now, + } + for _, opt := range opts { + opt(recorder) + } + return recorder +} + +func (r *FileRecorder) RecordRequest(toolName string, body []byte) { + r.record(toolName, body) +} + +func (r *FileRecorder) RecordResponse(toolName string, body []byte) { + r.record(toolName, body) +} + +func (r *FileRecorder) record(toolName string, body []byte) { + if r == nil || r.path == "" { + return + } + line, err := json.Marshal(Record{ + ToolName: toolName, + Timestamp: r.now(), + Body: json.RawMessage(body), + }) + if err != nil { + return + } + line = append(line, '\n') + + r.mu.Lock() + defer r.mu.Unlock() + file, err := os.OpenFile(r.path, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600) + 
if err != nil { + return + } + defer file.Close() + _, _ = file.Write(line) +} + +func Enabled(recorder Recorder) bool { + switch recorder.(type) { + case nil, NopRecorder, *NopRecorder: + return false + default: + return true + } +} diff --git a/router/internal/codemode/calltrace/calltrace_test.go b/router/internal/codemode/calltrace/calltrace_test.go new file mode 100644 index 0000000000..b9af2315ac --- /dev/null +++ b/router/internal/codemode/calltrace/calltrace_test.go @@ -0,0 +1,51 @@ +package calltrace + +import ( + "bufio" + "encoding/json" + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestFileRecorderWritesRequestAndResponseJSONL(t *testing.T) { + path := filepath.Join(t.TempDir(), "call-trace.jsonl") + now := time.Date(2026, 5, 4, 10, 30, 0, 0, time.UTC) + recorder := NewFileRecorder(path, WithNow(func() time.Time { return now })) + + recorder.RecordRequest("code_mode_run_js", []byte(`{"source":"async () => 1"}`)) + recorder.RecordResponse("code_mode_run_js", []byte(`{"content":[{"type":"text","text":"1"}]}`)) + + file, err := os.Open(path) + require.NoError(t, err) + defer file.Close() + + var got []Record + scanner := bufio.NewScanner(file) + for scanner.Scan() { + var record Record + require.NoError(t, json.Unmarshal(scanner.Bytes(), &record)) + got = append(got, record) + } + require.NoError(t, scanner.Err()) + assert.Equal(t, []Record{ + { + ToolName: "code_mode_run_js", + Timestamp: now, + Body: json.RawMessage(`{"source":"async () =\u003e 1"}`), + }, + { + ToolName: "code_mode_run_js", + Timestamp: now, + Body: json.RawMessage(`{"content":[{"type":"text","text":"1"}]}`), + }, + }, got) +} + +func TestNopRecorderIsDisabled(t *testing.T) { + assert.Equal(t, false, Enabled(NopRecorder{})) +} diff --git a/router/internal/codemode/deps.go b/router/internal/codemode/deps.go new file mode 100644 index 0000000000..ed36aea04c --- /dev/null +++ 
b/router/internal/codemode/deps.go @@ -0,0 +1,8 @@ +//go:build tools + +package codemode + +import ( + _ "github.com/evanw/esbuild/pkg/api" + _ "github.com/fastschema/qjs" +) diff --git a/router/internal/codemode/harness/envelope.go b/router/internal/codemode/harness/envelope.go new file mode 100644 index 0000000000..f4bfca7171 --- /dev/null +++ b/router/internal/codemode/harness/envelope.go @@ -0,0 +1,203 @@ +package harness + +import ( + "bytes" + "encoding/json" + "fmt" + "strings" + "unicode/utf8" + + "github.com/wundergraph/cosmo/router/internal/codemode/sandbox" +) + +const defaultMaxResultBytes = 32 << 10 +const previewBytes = 1 << 10 + +type ErrorEnvelope = sandbox.ErrorEnvelope + +// ResultEnvelope is the MCP-facing tool-result body for code_mode_run_js. +// +// Wire shape: +// - result is always present (null if the agent threw). +// - truncated is omitted on the wire when false (only signals a non-default state). +// - error is omitted on the wire when nil (only present on the throw path). 
+type ResultEnvelope struct { + Result json.RawMessage `json:"result"` + Truncated bool `json:"truncated,omitempty"` + Error *ErrorEnvelope `json:"error,omitempty"` +} + +func BuildEnvelope(sandboxResult sandbox.ExecuteResult, maxResultBytes int) (ResultEnvelope, error) { + if maxResultBytes <= 0 { + maxResultBytes = defaultMaxResultBytes + } + if !sandboxResult.OK { + return ResultEnvelope{ + Result: json.RawMessage("null"), + Truncated: false, + Error: cloneErrorEnvelope(sandboxResult.Error), + }, nil + } + if len(sandboxResult.Result) <= maxResultBytes { + return ResultEnvelope{Result: sandboxResult.Result, Truncated: false, Error: nil}, nil + } + + truncated, ok, err := structurallyTruncate(sandboxResult.Result, maxResultBytes) + if err != nil { + return ResultEnvelope{}, err + } + if ok { + return ResultEnvelope{Result: truncated, Truncated: true, Error: nil}, nil + } + fallback, err := previewEnvelope(sandboxResult.Result) + if err != nil { + return ResultEnvelope{}, err + } + return ResultEnvelope{Result: fallback, Truncated: true, Error: nil}, nil +} + +func cloneErrorEnvelope(err *ErrorEnvelope) *ErrorEnvelope { + if err == nil { + return nil + } + return &ErrorEnvelope{ + Name: err.Name, + Message: err.Message, + Stack: err.Stack, + Cause: cloneErrorEnvelope(err.Cause), + } +} + +func structurallyTruncate(raw json.RawMessage, maxBytes int) (json.RawMessage, bool, error) { + trimmed := bytes.TrimSpace(raw) + if len(trimmed) == 0 { + return nil, false, fmt.Errorf("empty JSON result") + } + switch trimmed[0] { + case '[': + items, err := splitJSONArray(trimmed) + if err != nil { + return nil, false, err + } + for keep := len(items); keep >= 0; keep-- { + body := joinJSON('[', ']', items[:keep]) + if len(body) <= maxBytes { + return body, true, nil + } + } + case '{': + fields, err := splitJSONObject(trimmed) + if err != nil { + return nil, false, err + } + for keep := len(fields); keep >= 0; keep-- { + body := joinJSON('{', '}', fields[:keep]) + if len(body) 
<= maxBytes { + return body, true, nil + } + } + } + return nil, false, nil +} + +func splitJSONArray(raw []byte) ([]json.RawMessage, error) { + if !json.Valid(raw) { + return nil, fmt.Errorf("invalid JSON result") + } + inner := bytes.TrimSpace(raw[1 : len(raw)-1]) + if len(inner) == 0 { + return nil, nil + } + return splitTopLevel(inner), nil +} + +func splitJSONObject(raw []byte) ([]json.RawMessage, error) { + if !json.Valid(raw) { + return nil, fmt.Errorf("invalid JSON result") + } + inner := bytes.TrimSpace(raw[1 : len(raw)-1]) + if len(inner) == 0 { + return nil, nil + } + return splitTopLevel(inner), nil +} + +func splitTopLevel(raw []byte) []json.RawMessage { + parts := make([]json.RawMessage, 0) + start := 0 + depth := 0 + inString := false + escaped := false + for i, b := range raw { + if inString { + if escaped { + escaped = false + } else if b == '\\' { + escaped = true + } else if b == '"' { + inString = false + } + continue + } + switch b { + case '"': + inString = true + case '[', '{': + depth++ + case ']', '}': + depth-- + case ',': + if depth == 0 { + parts = append(parts, bytes.TrimSpace(raw[start:i])) + start = i + 1 + } + } + } + parts = append(parts, bytes.TrimSpace(raw[start:])) + return parts +} + +func joinJSON(open byte, close byte, parts []json.RawMessage) json.RawMessage { + var b strings.Builder + b.WriteByte(open) + for i, part := range parts { + if i > 0 { + b.WriteByte(',') + } + b.Write(bytes.TrimSpace(part)) + } + b.WriteByte(close) + return json.RawMessage(b.String()) +} + +func previewEnvelope(raw json.RawMessage) (json.RawMessage, error) { + preview := string(raw) + var value string + if err := json.Unmarshal(raw, &value); err == nil { + preview = value + } + body, err := json.Marshal(struct { + Truncated bool `json:"__truncated"` + OriginalSize int `json:"originalSize"` + Preview string `json:"preview"` + }{ + Truncated: true, + OriginalSize: len(raw), + Preview: firstUTF8Bytes(preview, previewBytes), + }) + if err != nil { + 
return nil, err + } + return body, nil +} + +func firstUTF8Bytes(s string, limit int) string { + if len(s) <= limit { + return s + } + cut := limit + for cut > 0 && !utf8.ValidString(s[:cut]) { + cut-- + } + return s[:cut] +} diff --git a/router/internal/codemode/harness/envelope_test.go b/router/internal/codemode/harness/envelope_test.go new file mode 100644 index 0000000000..9c6ce7f3a5 --- /dev/null +++ b/router/internal/codemode/harness/envelope_test.go @@ -0,0 +1,61 @@ +package harness + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router/internal/codemode/sandbox" +) + +func TestBuildEnvelopePassesThroughSmallResult(t *testing.T) { + got, err := BuildEnvelope(sandbox.ExecuteResult{OK: true, Result: raw(`{"ok":true}`)}, 32<<10) + require.NoError(t, err) + + assert.Equal(t, ResultEnvelope{Result: raw(`{"ok":true}`), Truncated: false, Error: nil}, got) +} + +func TestBuildEnvelopeTruncatesTopLevelArray(t *testing.T) { + got, err := BuildEnvelope(sandbox.ExecuteResult{OK: true, Result: raw(`[{"id":1},{"id":2},{"id":3}]`)}, len(`[{"id":1},{"id":2}]`)) + require.NoError(t, err) + + assert.Equal(t, ResultEnvelope{Result: raw(`[{"id":1},{"id":2}]`), Truncated: true, Error: nil}, got) +} + +func TestBuildEnvelopeTruncatesTopLevelObject(t *testing.T) { + got, err := BuildEnvelope(sandbox.ExecuteResult{OK: true, Result: raw(`{"a":1,"b":2,"c":3}`)}, len(`{"a":1,"b":2}`)) + require.NoError(t, err) + + assert.Equal(t, ResultEnvelope{Result: raw(`{"a":1,"b":2}`), Truncated: true, Error: nil}, got) +} + +func TestBuildEnvelopeFallsBackToPreviewForHugeScalar(t *testing.T) { + value := strings.Repeat("a", 2048) + body, err := json.Marshal(value) + require.NoError(t, err) + + got, err := BuildEnvelope(sandbox.ExecuteResult{OK: true, Result: body}, 128) + require.NoError(t, err) + + var preview struct { + Truncated bool `json:"__truncated"` + OriginalSize int 
`json:"originalSize"` + Preview string `json:"preview"` + } + require.NoError(t, json.Unmarshal(got.Result, &preview)) + assert.Equal(t, true, got.Truncated) + assert.Equal(t, true, preview.Truncated) + assert.Equal(t, len(body), preview.OriginalSize) + assert.Equal(t, strings.Repeat("a", 1024), preview.Preview) +} + +func TestBuildEnvelopeCopiesSandboxError(t *testing.T) { + sandboxErr := &sandbox.ErrorEnvelope{Name: "Error", Message: "boom", Stack: "stack"} + + got, err := BuildEnvelope(sandbox.ExecuteResult{OK: false, Error: sandboxErr}, 32<<10) + require.NoError(t, err) + + assert.Equal(t, ResultEnvelope{Result: raw(`null`), Truncated: false, Error: sandboxErr}, got) +} diff --git a/router/internal/codemode/harness/pipeline.go b/router/internal/codemode/harness/pipeline.go new file mode 100644 index 0000000000..a11a607a72 --- /dev/null +++ b/router/internal/codemode/harness/pipeline.go @@ -0,0 +1,127 @@ +package harness + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + + "github.com/wundergraph/cosmo/router/internal/codemode/sandbox" +) + +const defaultMaxInputBytes = 64 << 10 + +type sandboxExecutor interface { + Execute(ctx context.Context, req sandbox.ExecuteRequest) (sandbox.ExecuteResult, error) +} + +type Executor interface { + Execute(ctx context.Context, req PipelineRequest) (PipelineResponse, error) +} + +type Pipeline struct { + Sandbox *sandbox.Sandbox + MaxInputBytes int + MaxResultBytes int + + executor sandboxExecutor +} + +type PipelineRequest struct { + SessionID string + ToolNames []string + Source string + RequestHeaders http.Header + ApprovalGate sandbox.ApprovalGate +} + +type PipelineResponse struct { + Envelope ResultEnvelope + Encoded []byte + Diagnostics []Diagnostic +} + +func (p *Pipeline) Execute(ctx context.Context, req PipelineRequest) (PipelineResponse, error) { + maxInputBytes := p.MaxInputBytes + if maxInputBytes <= 0 { + maxInputBytes = defaultMaxInputBytes + } + + // Raw-source guard rejects oversized 
input before esbuild parses it. The + // same limit applies post-transpile below because generated JS can differ + // slightly from source size. + if len(req.Source) > maxInputBytes { + return p.errorResponse(&ErrorEnvelope{ + Name: "InputTooLarge", + Message: fmt.Sprintf("input size %d bytes exceeds limit %d bytes", len(req.Source), maxInputBytes), + Stack: "", + }, nil) + } + + transpiled, err := Transpile(req.Source) + if err != nil { + return p.errorResponse(&ErrorEnvelope{Name: "TranspileError", Message: err.Error(), Stack: ""}, transpiled.Diagnostics) + } + + if len(transpiled.JS) > maxInputBytes { + return p.errorResponse(&ErrorEnvelope{ + Name: "InputTooLarge", + Message: fmt.Sprintf("input size %d bytes exceeds limit %d bytes", len(transpiled.JS), maxInputBytes), + Stack: "", + }, nil) + } + + if err := ShapeCheck(transpiled.JS); err != nil { + return p.errorResponse(&ErrorEnvelope{Name: "ShapeCheck", Message: err.Error(), Stack: ""}, nil) + } + + executor, err := p.sandboxExecutor() + if err != nil { + return PipelineResponse{}, err + } + sandboxResult, err := executor.Execute(ctx, sandbox.ExecuteRequest{ + SessionID: req.SessionID, + ToolNames: req.ToolNames, + WrappedJS: transpiled.JS, + SourceMap: transpiled.SourceMap, + RequestHeaders: req.RequestHeaders, + ApprovalGate: req.ApprovalGate, + }) + if err != nil { + return PipelineResponse{}, err + } + + envelope, err := BuildEnvelope(sandboxResult, p.MaxResultBytes) + if err != nil { + return PipelineResponse{}, err + } + encoded, err := json.Marshal(envelope) + if err != nil { + return PipelineResponse{}, err + } + return PipelineResponse{Envelope: envelope, Encoded: encoded}, nil +} + +func (p *Pipeline) sandboxExecutor() (sandboxExecutor, error) { + if p.executor != nil { + return p.executor, nil + } + if p.Sandbox == nil { + return nil, errors.New("code mode: pipeline sandbox is nil") + } + return p.Sandbox, nil +} + +func (p *Pipeline) errorResponse(errEnv *ErrorEnvelope, diagnostics []Diagnostic) 
(PipelineResponse, error) { + envelope := ResultEnvelope{ + Result: json.RawMessage("null"), + Truncated: false, + Error: errEnv, + } + encoded, err := json.Marshal(envelope) + if err != nil { + return PipelineResponse{}, err + } + return PipelineResponse{Envelope: envelope, Encoded: encoded, Diagnostics: diagnostics}, nil +} diff --git a/router/internal/codemode/harness/pipeline_test.go b/router/internal/codemode/harness/pipeline_test.go new file mode 100644 index 0000000000..981d943c3c --- /dev/null +++ b/router/internal/codemode/harness/pipeline_test.go @@ -0,0 +1,144 @@ +package harness + +import ( + "context" + "encoding/json" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router/internal/codemode/sandbox" +) + +type fakeExecutor struct { + calls int + result sandbox.ExecuteResult + err error +} + +func (f *fakeExecutor) Execute(ctx context.Context, req sandbox.ExecuteRequest) (sandbox.ExecuteResult, error) { + f.calls++ + return f.result, f.err +} + +func TestPipelineShapeCheckFailureShortCircuits(t *testing.T) { + fake := &fakeExecutor{} + pipeline := Pipeline{executor: fake} + + got, err := pipeline.Execute(context.Background(), PipelineRequest{Source: `() => 1`}) + require.NoError(t, err) + + assert.Equal(t, 0, fake.calls) + require.NotNil(t, got.Envelope.Error) + assert.Equal(t, "ShapeCheck", got.Envelope.Error.Name) + assert.Equal(t, `code mode: source must be a single async-arrow root (got: missing async modifier)`, got.Envelope.Error.Message) + assert.Empty(t, got.Diagnostics) + assert.NotEmpty(t, got.Encoded) +} + +func TestPipelineTopLevelAwaitFailsAtTranspile(t *testing.T) { + fake := &fakeExecutor{} + pipeline := Pipeline{executor: fake} + + got, err := pipeline.Execute(context.Background(), PipelineRequest{Source: `await tools.getUser({})`}) + require.NoError(t, err) + + assert.Equal(t, 0, fake.calls) + require.NotNil(t, got.Envelope.Error) + assert.Equal(t, 
"TranspileError", got.Envelope.Error.Name) + // esbuild's exact message is target-version dependent. We only assert the + // transpile-error envelope name; the full message lives in Diagnostics. + assert.NotEmpty(t, got.Diagnostics) +} + +func TestPipelineAcceptsTypeScriptSource(t *testing.T) { + fake := &fakeExecutor{result: sandbox.ExecuteResult{OK: true, Result: raw(`{"id":"1"}`)}} + pipeline := Pipeline{executor: fake} + + // TypeScript source: type annotations, optional params, type parameters. + // All three are valid TS-only syntax. Pipeline must transpile then accept. + tsInputs := []string{ + `async (x: string) => ({ id: x })`, + `async (x: string, y?: number) => ({ id: x })`, + `async (x: T) => ({ id: String(x) })`, + } + for _, in := range tsInputs { + t.Run(in, func(t *testing.T) { + fake.calls = 0 + got, err := pipeline.Execute(context.Background(), PipelineRequest{Source: in}) + require.NoError(t, err) + assert.Equal(t, 1, fake.calls, "sandbox should be invoked") + assert.Nil(t, got.Envelope.Error, "no shape or transpile error expected") + }) + } +} + +func TestPipelineTranspileFailureReturnsDiagnostics(t *testing.T) { + fake := &fakeExecutor{} + pipeline := Pipeline{executor: fake} + + got, err := pipeline.Execute(context.Background(), PipelineRequest{Source: `async () => { let x = ; }`}) + require.NoError(t, err) + + assert.Equal(t, 0, fake.calls) + require.NotNil(t, got.Envelope.Error) + assert.Equal(t, "TranspileError", got.Envelope.Error.Name) + assert.NotEmpty(t, got.Diagnostics) + assert.NotEmpty(t, got.Encoded) +} + +func TestPipelineSandboxErrorIsFoldedIntoEnvelope(t *testing.T) { + fake := &fakeExecutor{result: sandbox.ExecuteResult{ + OK: false, + Error: &sandbox.ErrorEnvelope{Name: "RuntimeError", Message: "boom", Stack: "stack"}, + }} + pipeline := Pipeline{executor: fake} + + got, err := pipeline.Execute(context.Background(), PipelineRequest{Source: `async () => 1`}) + require.NoError(t, err) + + assert.Equal(t, 1, fake.calls) + 
require.NotNil(t, got.Envelope.Error) + assert.Equal(t, "RuntimeError", got.Envelope.Error.Name) + assert.Equal(t, false, got.Envelope.Truncated) +} + +func TestPipelineSandboxSuccessEncodesEnvelope(t *testing.T) { + fake := &fakeExecutor{result: sandbox.ExecuteResult{OK: true, Result: raw(`{"ok":true}`)}} + pipeline := Pipeline{executor: fake} + + got, err := pipeline.Execute(context.Background(), PipelineRequest{ + SessionID: "session-1", + ToolNames: []string{"getUser"}, + Source: `async () => ({ ok: true })`, + RequestHeaders: http.Header{"Authorization": []string{"Bearer token"}}, + ApprovalGate: nil, + }) + require.NoError(t, err) + + assert.Equal(t, 1, fake.calls) + assert.Equal(t, ResultEnvelope{Result: raw(`{"ok":true}`), Truncated: false, Error: nil}, got.Envelope) + + var decoded map[string]any + require.NoError(t, json.Unmarshal(got.Encoded, &decoded)) + assert.Equal(t, map[string]any{"result": map[string]any{"ok": true}}, decoded) +} + +func TestPipelineTruncationTriggers(t *testing.T) { + result, err := json.Marshal([]any{map[string]any{"id": 1}, map[string]any{"id": 2}, map[string]any{"id": 3}}) + require.NoError(t, err) + + fake := &fakeExecutor{result: sandbox.ExecuteResult{OK: true, Result: result}} + pipeline := Pipeline{executor: fake, MaxResultBytes: len(`[{"id":1},{"id":2}]`)} + + got, err := pipeline.Execute(context.Background(), PipelineRequest{Source: `async () => []`}) + require.NoError(t, err) + + assert.Equal(t, true, got.Envelope.Truncated) + assert.Equal(t, raw(`[{"id":1},{"id":2}]`), got.Envelope.Result) +} + +func raw(s string) json.RawMessage { + return json.RawMessage(s) +} diff --git a/router/internal/codemode/harness/shape.go b/router/internal/codemode/harness/shape.go new file mode 100644 index 0000000000..56b1a9cd7a --- /dev/null +++ b/router/internal/codemode/harness/shape.go @@ -0,0 +1,97 @@ +package harness + +import ( + "errors" + "strings" + + "github.com/tdewolff/parse/v2" + "github.com/tdewolff/parse/v2/js" +) + +const 
shapeErrorPrefix = "code mode: source must be a single async-arrow root (got: " + +// ShapeCheck verifies that the given JavaScript source is exactly one +// top-level expression statement whose expression is an async arrow function. +// +// Input contract: ShapeCheck expects the *post-esbuild* JavaScript. TypeScript +// syntax is stripped earlier in the pipeline by Transpile (esbuild loaderTS). +// Callers must run Transpile first. +// +// Note: parse error messages from tdewolff include line/col positions for the +// post-esbuild source, NOT the original TS source the user wrote. That's +// acceptable because (a) ShapeCheck failures are structural, not character-level, +// and (b) Transpile already surfaces TS-source diagnostics for syntactic errors. +func ShapeCheck(source string) error { + if strings.TrimSpace(source) == "" { + return shapeError("empty source") + } + + ast, err := js.Parse(parse.NewInputBytes([]byte(source)), js.Options{}) + if err != nil { + return shapeError("parse failed: " + err.Error()) + } + + stmts := ast.BlockStmt.List + if len(stmts) == 0 { + return shapeError("empty source") + } + + // Detect import/export *before* the multi-statement check. Otherwise an + // input like `import x from "x"; async () => x` would report + // "multiple statements" instead of the more useful "leading import/export". + switch stmts[0].(type) { + case *js.ImportStmt, *js.ExportStmt: + return shapeError("leading import/export") + } + + if len(stmts) > 1 { + return shapeError("multiple statements") + } + + switch stmt := stmts[0].(type) { + case *js.ExprStmt: + return checkExpression(stmt.Value) + default: + return shapeError("non-arrow root") + } +} + +// checkExpression verifies the expression is an async arrow function, +// transparently unwrapping any number of redundant parentheses. 
+func checkExpression(expr js.IExpr) error { + for { + group, ok := expr.(*js.GroupExpr) + if !ok { + break + } + expr = group.X + } + + if isTopLevelAwait(expr) { + return shapeError("top-level await") + } + + arrow, ok := expr.(*js.ArrowFunc) + if !ok { + return shapeError("non-arrow root") + } + if !arrow.Async { + return shapeError("missing async modifier") + } + return nil +} + +// isTopLevelAwait detects `await x` used as a top-level expression. tdewolff +// parses await as a UnaryExpr with the Await operator. We surface this as a +// distinct error because it's a common model mistake worth flagging clearly. +func isTopLevelAwait(expr js.IExpr) bool { + unary, ok := expr.(*js.UnaryExpr) + if !ok { + return false + } + return unary.Op == js.AwaitToken +} + +func shapeError(reason string) error { + return errors.New(shapeErrorPrefix + reason + ")") +} diff --git a/router/internal/codemode/harness/shape_test.go b/router/internal/codemode/harness/shape_test.go new file mode 100644 index 0000000000..c6632e4389 --- /dev/null +++ b/router/internal/codemode/harness/shape_test.go @@ -0,0 +1,73 @@ +package harness + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +// ShapeCheck runs on post-esbuild JavaScript. Inputs in this file are written +// as the JS that Transpile would produce — never raw TypeScript. End-to-end +// TS handling is covered by pipeline_test.go and transpile_test.go. 
+ +func TestShapeCheckAcceptsAsyncArrowRoots(t *testing.T) { + tests := []string{ + `async () => 1`, + `async()=>1`, + `async () => { return 1; }`, + `async (x) => x`, + `async (x, y) => x + y`, + `async (x) => ({ x })`, + `(async () => 1)`, + `((async () => 1))`, + " \n\tasync () => true", + "// leading\nasync () => true", + "/* leading */ async () => true", + `async ({ id }) => id`, + `async () => await tools.getUser({ id: "1" })`, + `async () => { const rows = await Promise.all([]); return rows; }`, + } + for _, source := range tests { + t.Run(source, func(t *testing.T) { + assert.NoError(t, ShapeCheck(source)) + }) + } +} + +func TestShapeCheckRejectsNonAsyncArrowRoots(t *testing.T) { + tests := []struct { + name string + source string + want string + }{ + // Top-level await: ShapeCheck handles this defensively for the case where the + // pipeline's esbuild target is later raised to ES2022. Under today's ES2020 target, + // `await x` is rejected at Transpile and never reaches ShapeCheck — but the AST + // path still works as a unit, so we keep the test. + {name: "top-level await", source: `await tools.getUser({})`, want: `code mode: source must be a single async-arrow root (got: top-level await)`}, + // Import/export must be detected before the multi-statement check, otherwise + // `import x from "x"; async () => x` reports "multiple statements" instead. 
+ {name: "import then arrow", source: `import x from "x"; async () => x`, want: `code mode: source must be a single async-arrow root (got: leading import/export)`}, + {name: "import alone", source: `import x from "x"`, want: `code mode: source must be a single async-arrow root (got: leading import/export)`}, + {name: "export", source: `export default async () => 1`, want: `code mode: source must be a single async-arrow root (got: leading import/export)`}, + {name: "block", source: `{ async () => 1 }`, want: `code mode: source must be a single async-arrow root (got: non-arrow root)`}, + {name: "function declaration", source: `async function main() {}`, want: `code mode: source must be a single async-arrow root (got: non-arrow root)`}, + {name: "non async arrow", source: `() => 1`, want: `code mode: source must be a single async-arrow root (got: missing async modifier)`}, + {name: "paren non async arrow", source: `(() => 1)`, want: `code mode: source must be a single async-arrow root (got: missing async modifier)`}, + {name: "identifier", source: `foo`, want: `code mode: source must be a single async-arrow root (got: non-arrow root)`}, + {name: "empty", source: ` `, want: `code mode: source must be a single async-arrow root (got: empty source)`}, + {name: "comment-only", source: `// only trivia`, want: `code mode: source must be a single async-arrow root (got: empty source)`}, + {name: "async call", source: `async()`, want: `code mode: source must be a single async-arrow root (got: non-arrow root)`}, + {name: "multiple arrows", source: `async () => 1; async () => 2`, want: `code mode: source must be a single async-arrow root (got: multiple statements)`}, + {name: "var then arrow", source: `const x = 1; async () => x`, want: `code mode: source must be a single async-arrow root (got: multiple statements)`}, + {name: "class", source: `class X {}`, want: `code mode: source must be a single async-arrow root (got: non-arrow root)`}, + } + for _, tt := range tests { + 
t.Run(tt.name, func(t *testing.T) { + err := ShapeCheck(tt.source) + if assert.Error(t, err) { + assert.Equal(t, tt.want, err.Error()) + } + }) + } +} diff --git a/router/internal/codemode/harness/transpile.go b/router/internal/codemode/harness/transpile.go new file mode 100644 index 0000000000..129cf7bd3c --- /dev/null +++ b/router/internal/codemode/harness/transpile.go @@ -0,0 +1,73 @@ +package harness + +import ( + "errors" + "strings" + + "github.com/evanw/esbuild/pkg/api" +) + +type TranspileResult struct { + JS string + SourceMap []byte + Diagnostics []Diagnostic +} + +type Diagnostic struct { + Text string + Line int + Column int + File string +} + +func Transpile(source string) (TranspileResult, error) { + result := api.Transform(source, api.TransformOptions{ + Loader: api.LoaderTS, + Target: api.ES2020, + Platform: api.PlatformNeutral, + Format: api.FormatDefault, + Sourcemap: api.SourceMapExternal, + Sourcefile: "agent.ts", + LogLevel: api.LogLevelSilent, + LegalComments: api.LegalCommentsNone, + Drop: api.DropDebugger, + Charset: api.CharsetASCII, + }) + + out := TranspileResult{ + JS: trimTranspiledExpression(string(result.Code)), + SourceMap: append([]byte(nil), result.Map...), + Diagnostics: diagnosticsFromMessages(result.Errors), + } + if len(result.Errors) > 0 { + return out, errors.New("transpile failed: " + strings.Join(diagnosticTexts(out.Diagnostics), "; ")) + } + return out, nil +} + +func trimTranspiledExpression(js string) string { + trimmed := strings.TrimSpace(js) + return strings.TrimSuffix(trimmed, ";") +} + +func diagnosticsFromMessages(messages []api.Message) []Diagnostic { + diagnostics := make([]Diagnostic, 0, len(messages)) + for _, message := range messages { + diagnostic := Diagnostic{Text: message.Text} + if message.Location != nil { + diagnostic.Line = message.Location.Line + diagnostic.Column = message.Location.Column + 1 + diagnostic.File = message.Location.File + } + diagnostics = append(diagnostics, diagnostic) + } + return 
diagnostics +} + +func diagnosticTexts(diagnostics []Diagnostic) []string { + texts := make([]string, 0, len(diagnostics)) + for _, diagnostic := range diagnostics { + texts = append(texts, diagnostic.Text) + } + return texts +} diff --git a/router/internal/codemode/harness/transpile_test.go b/router/internal/codemode/harness/transpile_test.go new file mode 100644 index 0000000000..1d9cc21597 --- /dev/null +++ b/router/internal/codemode/harness/transpile_test.go @@ -0,0 +1,61 @@ +package harness + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTranspileStripsTypeScriptAnnotations(t *testing.T) { + got, err := Transpile(`async () => { const x: string = "hi"; return x; }`) + require.NoError(t, err) + + assert.NotContains(t, got.JS, `: string`) + assert.Contains(t, got.JS, `"hi"`) + assert.False(t, strings.HasSuffix(strings.TrimSpace(got.JS), ";")) + assert.NotEmpty(t, got.SourceMap) + assert.Empty(t, got.Diagnostics) + + var sourceMap map[string]any + require.NoError(t, json.Unmarshal(got.SourceMap, &sourceMap)) + assert.Equal(t, float64(3), sourceMap["version"]) +} + +func TestTranspileTreatsTypesAsNotation(t *testing.T) { + got, err := Transpile(`async (value: { id: string }): Promise => value.id`) + require.NoError(t, err) + + assert.NotContains(t, got.JS, `Promise`) + assert.NotContains(t, got.JS, `id: string`) + assert.Contains(t, got.JS, `value.id`) +} + +func TestTranspileReportsDiagnosticsForSyntaxErrors(t *testing.T) { + got, err := Transpile(`async () => { let x = ; }`) + require.Error(t, err) + + require.NotEmpty(t, got.Diagnostics) + assert.NotEmpty(t, got.Diagnostics[0].Text) + assert.NotEqual(t, 0, got.Diagnostics[0].Line) + assert.NotEqual(t, 0, got.Diagnostics[0].Column) + assert.True(t, strings.Contains(err.Error(), got.Diagnostics[0].Text)) +} + +func TestTranspileDropsDebuggerStatement(t *testing.T) { + got, err := Transpile(`async () => { debugger; 
return 1; }`) + require.NoError(t, err) + + assert.NotContains(t, got.JS, "debugger", "Drop:DropDebugger should remove debugger statements") +} + +func TestTranspileEscapesNonASCII(t *testing.T) { + got, err := Transpile(`async () => "héllo"`) + require.NoError(t, err) + + // CharsetASCII tells esbuild to escape non-ASCII codepoints in string + // literals. The raw `é` byte sequence must not appear in the output. + assert.NotContains(t, got.JS, "é", "Charset:ASCII should escape non-ASCII codepoints") +} diff --git a/router/internal/codemode/observability/logging.go b/router/internal/codemode/observability/logging.go new file mode 100644 index 0000000000..20b3e24810 --- /dev/null +++ b/router/internal/codemode/observability/logging.go @@ -0,0 +1,48 @@ +package observability + +import ( + "go.uber.org/zap" +) + +func LogSessionLifecycle(logger *zap.Logger, event string, sessionID string, fields ...zap.Field) { + if logger == nil { + return + } + allFields := append([]zap.Field{ + zap.String("event", event), + zap.String("session_id", sessionID), + }, fields...) + logger.Info("code mode session lifecycle", allFields...) 
+} + +func LogTranspileFailure(logger *zap.Logger, sessionID string, diagnostic string) { + if logger == nil { + return + } + logger.Info("code mode transpile failure", + zap.String("session_id", sessionID), + zap.String("diagnostic", diagnostic), + ) +} + +func LogElicitationOutcome(logger *zap.Logger, sessionID string, approved bool, reason string) { + if logger == nil { + return + } + logger.Info("code mode elicitation outcome", + zap.String("session_id", sessionID), + zap.Bool("approved", approved), + zap.String("reason", reason), + ) +} + +func LogToolInvocationFailure(logger *zap.Logger, sessionID string, opName string, err error) { + if logger == nil { + return + } + logger.Info("code mode tool invocation failure", + zap.String("session_id", sessionID), + zap.String("op_name", opName), + zap.Error(err), + ) +} diff --git a/router/internal/codemode/observability/logging_test.go b/router/internal/codemode/observability/logging_test.go new file mode 100644 index 0000000000..0a883f4ab1 --- /dev/null +++ b/router/internal/codemode/observability/logging_test.go @@ -0,0 +1,43 @@ +package observability + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "go.uber.org/zap/zaptest/observer" +) + +func TestLoggingHelpersEmitStructuredInfoEntries(t *testing.T) { + core, observed := observer.New(zapcore.InfoLevel) + logger := zap.New(core) + + LogSessionLifecycle(logger, "created", "session-1", zap.String("storage", "memory")) + LogTranspileFailure(logger, "session-1", "Unexpected \";\"") + LogElicitationOutcome(logger, "session-1", false, "operator declined") + LogToolInvocationFailure(logger, "session-1", "getOrders", errors.New("upstream timeout")) + + entries := observed.AllUntimed() + require.Len(t, entries, 4) + assert.Equal(t, []observer.LoggedEntry{ + { + Entry: zapcore.Entry{Level: zapcore.InfoLevel, Message: "code mode session lifecycle"}, + Context: 
[]zapcore.Field{zap.String("event", "created"), zap.String("session_id", "session-1"), zap.String("storage", "memory")}, + }, + { + Entry: zapcore.Entry{Level: zapcore.InfoLevel, Message: "code mode transpile failure"}, + Context: []zapcore.Field{zap.String("session_id", "session-1"), zap.String("diagnostic", "Unexpected \";\"")}, + }, + { + Entry: zapcore.Entry{Level: zapcore.InfoLevel, Message: "code mode elicitation outcome"}, + Context: []zapcore.Field{zap.String("session_id", "session-1"), zap.Bool("approved", false), zap.String("reason", "operator declined")}, + }, + { + Entry: zapcore.Entry{Level: zapcore.InfoLevel, Message: "code mode tool invocation failure"}, + Context: []zapcore.Field{zap.String("session_id", "session-1"), zap.String("op_name", "getOrders"), zap.Error(errors.New("upstream timeout"))}, + }, + }, entries) +} diff --git a/router/internal/codemode/observability/metrics.go b/router/internal/codemode/observability/metrics.go new file mode 100644 index 0000000000..a9aaf18428 --- /dev/null +++ b/router/internal/codemode/observability/metrics.go @@ -0,0 +1,56 @@ +package observability + +import ( + "context" + + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" +) + +const meterName = "wundergraph.cosmo.router.mcp.code_mode" + +type Meter struct { + executionsCounter metric.Int64Counter + durationHistogram metric.Float64Histogram +} + +func NewMeter(meterProvider metric.MeterProvider) (*Meter, error) { + if meterProvider == nil { + meterProvider = otel.GetMeterProvider() + } + meter := meterProvider.Meter(meterName) + + executionsCounter, err := meter.Int64Counter( + "mcp.code_mode.sandbox.executions", + metric.WithDescription("Code Mode sandbox executions."), + ) + if err != nil { + return nil, err + } + durationHistogram, err := meter.Float64Histogram( + "mcp.code_mode.sandbox.duration", + metric.WithDescription("Code Mode sandbox execution duration."), + metric.WithUnit("ms"), + ) + if err != 
nil { + return nil, err + } + + return &Meter{ + executionsCounter: executionsCounter, + durationHistogram: durationHistogram, + }, nil +} + +func (m *Meter) Record(ctx context.Context, toolName, status string, durationMs float64) { + if m == nil { + return + } + attrs := metric.WithAttributes( + attribute.String("mcp.tool", toolName), + attribute.String("mcp.status", status), + ) + m.executionsCounter.Add(ctx, 1, attrs) + m.durationHistogram.Record(ctx, durationMs, attrs) +} diff --git a/router/internal/codemode/observability/metrics_test.go b/router/internal/codemode/observability/metrics_test.go new file mode 100644 index 0000000000..e39a1c6021 --- /dev/null +++ b/router/internal/codemode/observability/metrics_test.go @@ -0,0 +1,76 @@ +package observability + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/attribute" + sdkmetric "go.opentelemetry.io/otel/sdk/metric" + "go.opentelemetry.io/otel/sdk/metric/metricdata" +) + +func TestMeterRecordEmitsCounterAndDurationHistogram(t *testing.T) { + reader := sdkmetric.NewManualReader() + provider := sdkmetric.NewMeterProvider(sdkmetric.WithReader(reader)) + meter, err := NewMeter(provider) + require.NoError(t, err) + + meter.Record(context.Background(), "code_mode_run_js", "success", 12.5) + + var got metricdata.ResourceMetrics + require.NoError(t, reader.Collect(context.Background(), &got)) + counter, histogram := codeModeMetrics(t, got) + + counterData, ok := counter.Data.(metricdata.Sum[int64]) + require.True(t, ok) + require.Len(t, counterData.DataPoints, 1) + counterPoint := counterData.DataPoints[0] + counterPoint.StartTime = time.Time{} + counterPoint.Time = time.Time{} + assert.Equal(t, metricdata.DataPoint[int64]{ + Attributes: attribute.NewSet( + attribute.String("mcp.tool", "code_mode_run_js"), + attribute.String("mcp.status", "success"), + ), + Value: 1, + }, counterPoint) + + histogramData, ok := 
histogram.Data.(metricdata.Histogram[float64]) + require.True(t, ok) + require.Len(t, histogramData.DataPoints, 1) + histogramPoint := histogramData.DataPoints[0] + histogramPoint.StartTime = time.Time{} + histogramPoint.Time = time.Time{} + assert.Equal(t, metricdata.HistogramDataPoint[float64]{ + Attributes: attribute.NewSet( + attribute.String("mcp.tool", "code_mode_run_js"), + attribute.String("mcp.status", "success"), + ), + Count: 1, + Bounds: histogramPoint.Bounds, + BucketCounts: histogramPoint.BucketCounts, + Min: histogramPoint.Min, + Max: histogramPoint.Max, + Sum: 12.5, + }, histogramPoint) +} + +func codeModeMetrics(t *testing.T, metrics metricdata.ResourceMetrics) (metricdata.Metrics, metricdata.Metrics) { + t.Helper() + require.Len(t, metrics.ScopeMetrics, 1) + assert.Equal(t, "wundergraph.cosmo.router.mcp.code_mode", metrics.ScopeMetrics[0].Scope.Name) + + byName := make(map[string]metricdata.Metrics, len(metrics.ScopeMetrics[0].Metrics)) + for _, metric := range metrics.ScopeMetrics[0].Metrics { + byName[metric.Name] = metric + } + + counter, ok := byName["mcp.code_mode.sandbox.executions"] + require.True(t, ok) + histogram, ok := byName["mcp.code_mode.sandbox.duration"] + require.True(t, ok) + return counter, histogram +} diff --git a/router/internal/codemode/observability/tracing.go b/router/internal/codemode/observability/tracing.go new file mode 100644 index 0000000000..70456d8cf9 --- /dev/null +++ b/router/internal/codemode/observability/tracing.go @@ -0,0 +1,36 @@ +package observability + +import ( + "context" + + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" +) + +const tracerName = "wundergraph.cosmo.router.mcp.code_mode" + +func StartToolSpan(ctx context.Context, toolName string) (context.Context, trace.Span) { + return StartToolSpanWithProvider(ctx, otel.GetTracerProvider(), toolName) +} + +func StartToolSpanWithProvider(ctx context.Context, tracerProvider trace.TracerProvider, 
toolName string) (context.Context, trace.Span) { + if tracerProvider == nil { + tracerProvider = otel.GetTracerProvider() + } + return tracerProvider.Tracer(tracerName).Start(ctx, toolSpanName(toolName), + trace.WithSpanKind(trace.SpanKindServer), + trace.WithAttributes(attribute.String("mcp.tool", toolName)), + ) +} + +func toolSpanName(toolName string) string { + switch toolName { + case "code_mode_search_tools": + return "MCP Code Mode - Search" + case "code_mode_run_js": + return "MCP Code Mode - Execute" + default: + return "MCP Code Mode - " + toolName + } +} diff --git a/router/internal/codemode/observability/tracing_test.go b/router/internal/codemode/observability/tracing_test.go new file mode 100644 index 0000000000..6ce73bf0a9 --- /dev/null +++ b/router/internal/codemode/observability/tracing_test.go @@ -0,0 +1,67 @@ +package observability + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + sdktrace "go.opentelemetry.io/otel/sdk/trace" + "go.opentelemetry.io/otel/sdk/trace/tracetest" + "go.opentelemetry.io/otel/trace" +) + +func TestStartToolSpanRecordsSearchServerSpan(t *testing.T) { + recorder := tracetest.NewSpanRecorder() + provider := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(recorder)) + previous := otel.GetTracerProvider() + otel.SetTracerProvider(provider) + t.Cleanup(func() { otel.SetTracerProvider(previous) }) + + _, span := StartToolSpan(context.Background(), "code_mode_search_tools") + span.End() + + ended := recorder.Ended() + require.Len(t, ended, 1) + stub := tracetest.SpanStubFromReadOnlySpan(ended[0]) + stub.SpanContext = trace.SpanContext{} + stub.StartTime = time.Time{} + stub.EndTime = time.Time{} + stub.Resource = nil + assert.Equal(t, tracetest.SpanStub{ + Name: "MCP Code Mode - Search", + SpanKind: trace.SpanKindServer, + Attributes: []attribute.KeyValue{ + 
attribute.String("mcp.tool", "code_mode_search_tools"), + }, + InstrumentationLibrary: stub.InstrumentationLibrary, + }, stub) +} + +func TestStartToolSpanRecordsExecuteServerSpan(t *testing.T) { + recorder := tracetest.NewSpanRecorder() + provider := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(recorder)) + + ctx, span := StartToolSpanWithProvider(context.Background(), provider, "code_mode_run_js") + require.True(t, trace.SpanFromContext(ctx).SpanContext().IsValid()) + span.End() + + ended := recorder.Ended() + require.Len(t, ended, 1) + stub := tracetest.SpanStubFromReadOnlySpan(ended[0]) + stub.SpanContext = trace.SpanContext{} + stub.StartTime = time.Time{} + stub.EndTime = time.Time{} + stub.Resource = nil + assert.Equal(t, tracetest.SpanStub{ + Name: "MCP Code Mode - Execute", + SpanKind: trace.SpanKindServer, + Attributes: []attribute.KeyValue{ + attribute.String("mcp.tool", "code_mode_run_js"), + }, + InstrumentationLibrary: stub.InstrumentationLibrary, + }, stub) +} diff --git a/router/internal/codemode/sandbox/errors.go b/router/internal/codemode/sandbox/errors.go new file mode 100644 index 0000000000..c22e75e7be --- /dev/null +++ b/router/internal/codemode/sandbox/errors.go @@ -0,0 +1,201 @@ +package sandbox + +import ( + "encoding/json" + "regexp" + "strconv" + "strings" + + "github.com/fastschema/qjs" +) + +func normalizeError(ctx *qjs.Context, errValue *qjs.Value, sourceMap []byte, program string) (*ErrorEnvelope, error) { + global := ctx.Global() + normalizer := global.GetPropertyStr("__codemodeNormalizeErrorJSON") + encoded, err := ctx.Invoke(normalizer, global, errValue) + if err != nil { + return nil, err + } + + var envelope ErrorEnvelope + if err := json.Unmarshal([]byte(encoded.String()), &envelope); err != nil { + return nil, err + } + envelope.Stack = rewriteStack(envelope.Stack, sourceMap, userCodeStartLine(program)) + rewriteCauseStacks(envelope.Cause, sourceMap, program) + return &envelope, nil +} + +var toolsCallRE = 
regexp.MustCompile(`tools\.([A-Za-z_$][A-Za-z0-9_$]*)\s*\(`) + +func missingToolName(source string, known []string) string { + knownSet := map[string]struct{}{} + for _, name := range known { + knownSet[name] = struct{}{} + } + for _, match := range toolsCallRE.FindAllStringSubmatch(source, -1) { + if len(match) != 2 { + continue + } + if _, ok := knownSet[match[1]]; !ok { + return match[1] + } + } + return "" +} + +func rewriteCauseStacks(err *ErrorEnvelope, sourceMap []byte, program string) { + for err != nil { + err.Stack = rewriteStack(err.Stack, sourceMap, userCodeStartLine(program)) + err = err.Cause + } +} + +var stackLocationRE = regexp.MustCompile(`(?:\w+\.js:)?(\d+):(\d+)`) + +func rewriteStack(stack string, sourceMap []byte, userStartLine int) string { + if len(sourceMap) == 0 || stack == "" { + return stack + } + sm, err := parseSourceMap(sourceMap) + if err != nil { + return stack + } + return stackLocationRE.ReplaceAllStringFunc(stack, func(match string) string { + parts := stackLocationRE.FindStringSubmatch(match) + if len(parts) != 3 { + return match + } + line, err := strconv.Atoi(parts[1]) + if err != nil { + return match + } + col, err := strconv.Atoi(parts[2]) + if err != nil { + return match + } + generatedLine := line - userStartLine + 1 + if generatedLine < 1 { + return match + } + mapped, ok := sm.lookup(generatedLine, col) + if !ok { + return match + } + prefix := strings.TrimSuffix(match, parts[1]+":"+parts[2]) + return prefix + mapped.source + ":" + strconv.Itoa(mapped.line) + ":" + strconv.Itoa(mapped.column) + }) +} + +type sourceMap struct { + lines [][]mapping +} + +type mapping struct { + generatedColumn int + source string + line int + column int +} + +func parseSourceMap(data []byte) (*sourceMap, error) { + var raw struct { + Sources []string `json:"sources"` + Mappings string `json:"mappings"` + } + if err := json.Unmarshal(data, &raw); err != nil { + return nil, err + } + + sm := &sourceMap{lines: make([][]mapping, 0)} + var 
sourceIndex, originalLine, originalColumn int + for _, lineMappings := range strings.Split(raw.Mappings, ";") { + var generatedColumn int + line := make([]mapping, 0) + for _, segment := range strings.Split(lineMappings, ",") { + if segment == "" { + continue + } + values, err := decodeVLQSegment(segment) + if err != nil { + return nil, err + } + if len(values) < 4 { + continue + } + generatedColumn += values[0] + sourceIndex += values[1] + originalLine += values[2] + originalColumn += values[3] + if sourceIndex >= 0 && sourceIndex < len(raw.Sources) { + line = append(line, mapping{ + generatedColumn: generatedColumn, + source: raw.Sources[sourceIndex], + line: originalLine + 1, + column: originalColumn + 1, + }) + } + } + sm.lines = append(sm.lines, line) + } + return sm, nil +} + +func (sm *sourceMap) lookup(generatedLine, generatedColumn int) (mapping, bool) { + if generatedLine < 1 || generatedLine > len(sm.lines) { + return mapping{}, false + } + line := sm.lines[generatedLine-1] + if len(line) == 0 { + return mapping{}, false + } + column0 := generatedColumn - 1 + best := line[0] + for _, candidate := range line { + if candidate.generatedColumn > column0 { + break + } + best = candidate + } + return best, true +} + +const vlqBaseShift = 5 +const vlqBase = 1 << vlqBaseShift +const vlqBaseMask = vlqBase - 1 +const vlqContinuationBit = vlqBase + +var base64VLQ = map[rune]int{ + 'A': 0, 'B': 1, 'C': 2, 'D': 3, 'E': 4, 'F': 5, 'G': 6, 'H': 7, + 'I': 8, 'J': 9, 'K': 10, 'L': 11, 'M': 12, 'N': 13, 'O': 14, 'P': 15, + 'Q': 16, 'R': 17, 'S': 18, 'T': 19, 'U': 20, 'V': 21, 'W': 22, 'X': 23, + 'Y': 24, 'Z': 25, 'a': 26, 'b': 27, 'c': 28, 'd': 29, 'e': 30, 'f': 31, + 'g': 32, 'h': 33, 'i': 34, 'j': 35, 'k': 36, 'l': 37, 'm': 38, 'n': 39, + 'o': 40, 'p': 41, 'q': 42, 'r': 43, 's': 44, 't': 45, 'u': 46, 'v': 47, + 'w': 48, 'x': 49, 'y': 50, 'z': 51, '0': 52, '1': 53, '2': 54, '3': 55, + '4': 56, '5': 57, '6': 58, '7': 59, '8': 60, '9': 61, '+': 62, '/': 63, +} + +func 
decodeVLQSegment(segment string) ([]int, error) { + values := make([]int, 0, 4) + var value, shift int + for _, r := range segment { + digit := base64VLQ[r] + continuation := digit&vlqContinuationBit != 0 + digit &= vlqBaseMask + value += digit << shift + if continuation { + shift += vlqBaseShift + continue + } + negative := value&1 == 1 + value >>= 1 + if negative { + value = -value + } + values = append(values, value) + value = 0 + shift = 0 + } + return values, nil +} diff --git a/router/internal/codemode/sandbox/execute.go b/router/internal/codemode/sandbox/execute.go new file mode 100644 index 0000000000..a983730368 --- /dev/null +++ b/router/internal/codemode/sandbox/execute.go @@ -0,0 +1,210 @@ +package sandbox + +import ( + "context" + "encoding/json" + "fmt" + "io" + "strings" + "time" + + "github.com/fastschema/qjs" +) + +func (s *Sandbox) Execute(ctx context.Context, req ExecuteRequest) (execResult ExecuteResult, retErr error) { + if err := s.acquire(ctx); err != nil { + return ExecuteResult{}, err + } + defer s.release() + + // qjs v0.0.6 panics from inside its Eval/Free/Close paths when the underlying + // wazero module is closed by context cancellation (e.g. host call exceeded + // the sandbox wall-clock). Recover here so a panicking call cannot crash the + // router goroutine; surface as a Timeout envelope instead. 
+ defer func() { + if r := recover(); r != nil { + errEnv := &ErrorEnvelope{Name: "Timeout", Message: fmt.Sprintf("sandbox runtime panic: %v", r)} + if ctx.Err() != nil { + errEnv.Message = ctx.Err().Error() + } + execResult = ExecuteResult{OK: false, Error: errEnv, OutputSize: envelopeSize(nil, errEnv)} + retErr = nil + } + }() + + program := buildPreamble(req.WrappedJS) + if len(program) > s.cfg.MaxInputSizeBytes { + errEnv := &ErrorEnvelope{ + Name: "InputTooLarge", + Message: fmt.Sprintf("input size %d bytes exceeds limit %d bytes", len(program), s.cfg.MaxInputSizeBytes), + Stack: "", + } + return ExecuteResult{OK: false, Error: errEnv, OutputSize: envelopeSize(nil, errEnv)}, nil + } + + execCtx, cancel := context.WithTimeout(ctx, s.cfg.RequestTimeout) + defer cancel() + + rt, err := qjs.New(qjs.Option{ + Context: execCtx, + CloseOnContextDone: true, + DisableBuildCache: true, + MemoryLimit: s.cfg.MemoryLimitBytes, + MaxExecutionTime: int(s.cfg.RequestTimeout / time.Millisecond), + Stdout: io.Discard, + Stderr: io.Discard, + }) + if err != nil { + return runtimeErrorResult(err, execCtx, 0), nil + } + + qctx := rt.Context() + state := &executeState{req: req} + defer func() { + state.wg.Wait() + // qjs panics on Close when the runtime context has already been cancelled. + // Treat the runtime as best-effort cleanup; a leaked WASM instance is bounded + // by GC and the per-call freshness contract. 
+ defer func() { _ = recover() }() + rt.Close() + }() + s.installHostInvoke(execCtx, qctx, state) + if err := installValidationHelpers(qctx); err != nil { + return runtimeErrorResult(err, execCtx, int(state.hostCalls.Load())), nil + } + + global := qctx.Global() + toolNames := req.ToolNames + if toolNames == nil { + toolNames = []string{} + } + namesJSON, err := json.Marshal(toolNames) + if err != nil { + return ExecuteResult{}, err + } + names := qctx.ParseJSON(string(namesJSON)) + global.SetPropertyStr("__HOST_TOOL_NAMES", names) + + value, err := qctx.Eval("codemode_agent.js", qjs.Code(program)) + if err != nil { + return runtimeErrorResult(err, execCtx, int(state.hostCalls.Load())), nil + } + + value, err = awaitWithContext(execCtx, rt, value) + if err != nil { + return runtimeErrorResult(err, execCtx, int(state.hostCalls.Load())), nil + } + okValue := value.GetPropertyStr("ok") + ok := okValue.Bool() + + if !ok { + errValue := value.GetPropertyStr("error") + errEnv, err := normalizeError(qctx, errValue, req.SourceMap, program) + if err != nil { + return runtimeErrorResult(err, execCtx, int(state.hostCalls.Load())), nil + } + if errEnv.Name == "InternalError" { + errEnv.Name = "MemoryLimit" + } + if errEnv.Name == "TypeError" && errEnv.Message == "not a function" { + if missing := missingToolName(req.WrappedJS, req.ToolNames); missing != "" { + errEnv.Message = "tools." 
+ missing + " is not a function" + } + } + hostCalls := int(state.hostCalls.Load()) + if errEnv.Name == "HostCallLimitExceeded" { + hostCalls = s.cfg.MaxToolInvocationsPerCall + 1 + } + return ExecuteResult{ + OK: false, + Error: errEnv, + OutputSize: envelopeSize(nil, errEnv), + HostCalls: hostCalls, + }, nil + } + + resultValue := value.GetPropertyStr("result") + result, validationErr, err := validateResult(qctx, resultValue, s.cfg.MaxOutputSizeBytes) + if err != nil { + return runtimeErrorResult(err, execCtx, int(state.hostCalls.Load())), nil + } + if validationErr != nil { + return ExecuteResult{ + OK: false, + Error: validationErr, + OutputSize: envelopeSize(nil, validationErr), + HostCalls: int(state.hostCalls.Load()), + }, nil + } + return ExecuteResult{ + OK: true, + Result: result, + OutputSize: envelopeSize(result, nil), + HostCalls: int(state.hostCalls.Load()), + }, nil +} + +type awaitResult struct { + value *qjs.Value + err error +} + +func awaitWithContext(ctx context.Context, rt *qjs.Runtime, value *qjs.Value) (*qjs.Value, error) { + if !value.IsPromise() { + return value, nil + } + + done := make(chan awaitResult, 1) + go func() { + awaited, err := value.Await() + done <- awaitResult{value: awaited, err: err} + }() + + select { + case result := <-done: + return result.value, result.err + case <-ctx.Done(): + // Best-effort runtime close so the await goroutine unblocks; the deferred + // close in Execute owns the canonical cleanup (and recovers any qjs panic). 
+ func() { + defer func() { _ = recover() }() + rt.Close() + }() + select { + case result := <-done: + _ = result + case <-time.After(100 * time.Millisecond): + } + return nil, ctx.Err() + } +} + +func runtimeErrorResult(err error, ctx context.Context, hostCalls int) ExecuteResult { + errEnv := classifyRuntimeError(err, ctx) + return ExecuteResult{ + OK: false, + Error: errEnv, + OutputSize: envelopeSize(nil, errEnv), + HostCalls: hostCalls, + } +} + +func classifyRuntimeError(err error, ctx context.Context) *ErrorEnvelope { + if ctx.Err() != nil { + return &ErrorEnvelope{Name: "Timeout", Message: ctx.Err().Error(), Stack: ""} + } + msg := err.Error() + lower := strings.ToLower(msg) + if strings.Contains(lower, "memory") || strings.Contains(lower, "out of memory") { + return &ErrorEnvelope{Name: "MemoryLimit", Message: msg, Stack: ""} + } + return &ErrorEnvelope{Name: "Error", Message: msg, Stack: ""} +} + +func envelopeSize(result json.RawMessage, errEnv *ErrorEnvelope) int { + if errEnv != nil { + body, _ := json.Marshal(errEnv) + return len(body) + } + return len(result) +} diff --git a/router/internal/codemode/sandbox/headers.go b/router/internal/codemode/sandbox/headers.go new file mode 100644 index 0000000000..100ae0be85 --- /dev/null +++ b/router/internal/codemode/sandbox/headers.go @@ -0,0 +1,44 @@ +package sandbox + +import ( + "net/http" + "strings" +) + +var hopByHopHeaders = map[string]struct{}{ + "connection": {}, + "keep-alive": {}, + "proxy-authenticate": {}, + "proxy-authorization": {}, + "te": {}, + "trailer": {}, + "transfer-encoding": {}, + "upgrade": {}, +} + +func headerAllowList(headers []string) map[string]struct{} { + allow := make(map[string]struct{}, len(headers)) + for _, h := range headers { + canonical := strings.ToLower(http.CanonicalHeaderKey(h)) + if _, hop := hopByHopHeaders[canonical]; hop { + continue + } + allow[canonical] = struct{}{} + } + return allow +} + +func copyAllowedHeaders(dst, src http.Header, allow 
map[string]struct{}) { + for name, values := range src { + canonical := strings.ToLower(http.CanonicalHeaderKey(name)) + if _, hop := hopByHopHeaders[canonical]; hop { + continue + } + if _, ok := allow[canonical]; !ok { + continue + } + for _, value := range values { + dst.Add(name, value) + } + } +} diff --git a/router/internal/codemode/sandbox/host.go b/router/internal/codemode/sandbox/host.go new file mode 100644 index 0000000000..c0fce81cc8 --- /dev/null +++ b/router/internal/codemode/sandbox/host.go @@ -0,0 +1,242 @@ +package sandbox + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "sync" + "sync/atomic" + + "github.com/fastschema/qjs" + "github.com/wundergraph/cosmo/router/internal/codemode/observability" + "github.com/wundergraph/cosmo/router/internal/codemode/storage" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" +) + +// TODO(code-mode §9): The plan calls for channel-based async host calls so +// Promise.all can overlap HTTP work. qjs v0.0.6 SetAsyncFunc invokes the Go +// callback synchronously on the QuickJS/Wazero call path, and resolving from +// arbitrary goroutines is not supported by the wrapper without a JS-thread +// drain loop, so host calls remain serialized for the MVP. 
+ +type executeState struct { + req ExecuteRequest + hostCalls atomic.Int32 + qjsMu sync.Mutex + wg sync.WaitGroup +} + +func (s *Sandbox) installHostInvoke(ctx context.Context, qctx *qjs.Context, state *executeState) { + qctx.SetAsyncFunc("__hostInvokeTool", func(this *qjs.This) { + args := this.Args() + name := "" + if len(args) > 0 && !args[0].IsUndefined() && !args[0].IsNull() { + name = args[0].String() + } + vars, err := varsJSON(args) + if err != nil { + resolveString(this, state, hostErrorPayload("TypeError", err.Error())) + return + } + + result, invokeErr := s.invokeTool(ctx, state, name, vars) + if invokeErr != nil { + resolveString(this, state, hostErrorPayload(invokeErr.name, invokeErr.message)) + return + } + resolveString(this, state, string(result)) + }) +} + +func resolveString(this *qjs.This, state *executeState, payload string) { + state.qjsMu.Lock() + defer state.qjsMu.Unlock() + this.Promise().Resolve(this.Context().NewString(payload)) +} + +func hostErrorPayload(name, message string) string { + body, _ := json.Marshal(map[string]any{ + "__codemodeHostError": map[string]string{ + "name": name, + "message": message, + }, + }) + return string(body) +} + +type hostError struct { + name string + message string +} + +func varsJSON(args []*qjs.Value) (json.RawMessage, error) { + if len(args) < 2 || args[1].IsUndefined() || args[1].IsNull() { + return json.RawMessage(`{}`), nil + } + jsonString, err := args[1].JSONStringify() + if err != nil { + return nil, err + } + if jsonString == "" || jsonString == "null" { + return json.RawMessage(`{}`), nil + } + return json.RawMessage(jsonString), nil +} + +func (s *Sandbox) invokeTool(ctx context.Context, state *executeState, name string, vars json.RawMessage) (json.RawMessage, *hostError) { + count := int(state.hostCalls.Add(1)) + if count > s.cfg.MaxToolInvocationsPerCall { + return nil, &hostError{ + name: "HostCallLimitExceeded", + message: fmt.Sprintf("tools.* invocation cap of %d exceeded; batch 
independent calls with Promise.all.", s.cfg.MaxToolInvocationsPerCall), + } + } + + op, ok, err := s.cfg.StorageLookup(ctx, state.req.SessionID, name) + if err != nil { + observability.LogToolInvocationFailure(s.cfg.Logger, state.req.SessionID, name, err) + return nil, &hostError{name: "Error", message: err.Error()} + } + if !ok { + err := fmt.Errorf("tools.%s is not a function", name) + observability.LogToolInvocationFailure(s.cfg.Logger, state.req.SessionID, name, err) + return nil, &hostError{name: "TypeError", message: err.Error()} + } + + span := trace.SpanFromContext(ctx) + span.SetAttributes( + attribute.String("codemode.op.name", op.Name), + attribute.String("codemode.op.kind", string(op.Kind)), + ) + + if op.Kind == storage.OperationKindMutation { + gate := state.req.ApprovalGate + if gate == nil { + gate = approveAllGate{} + } + decision, err := gate.Decide(ctx, ApprovalRequest{Name: name, Source: op.Body, Vars: vars}) + if err != nil { + observability.LogToolInvocationFailure(s.cfg.Logger, state.req.SessionID, name, err) + return nil, &hostError{name: "Error", message: err.Error()} + } + span.SetAttributes( + attribute.Bool("code_mode.mutation.approved", decision.Approved), + attribute.String("code_mode.mutation.reason", decision.Reason), + ) + if !decision.Approved { + body := mutationDeclinedResponse(decision.Reason) + span.SetAttributes(attribute.Bool("codemode.op.success", false)) + return body, nil + } + } + + body, err := json.Marshal(graphQLRequest{ + Query: op.Body, + OperationName: name, + Variables: vars, + }) + if err != nil { + return nil, &hostError{name: "Error", message: err.Error()} + } + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, s.cfg.RouterGraphQLEndpoint, bytes.NewReader(body)) + if err != nil { + return nil, &hostError{name: "Error", message: err.Error()} + } + copyAllowedHeaders(httpReq.Header, state.req.RequestHeaders, s.allowList) + httpReq.Header.Set("Content-Type", "application/json") + + resp, err := 
s.http.Do(httpReq) + if err != nil { + span.SetAttributes(attribute.Bool("codemode.op.success", false)) + observability.LogToolInvocationFailure(s.cfg.Logger, state.req.SessionID, name, err) + return nil, &hostError{name: "Error", message: err.Error()} + } + defer resp.Body.Close() + + respBody, err := readCapped(resp.Body, s.cfg.MaxResponseBodyBytes) + if err != nil { + span.SetAttributes(attribute.Bool("codemode.op.success", false)) + observability.LogToolInvocationFailure(s.cfg.Logger, state.req.SessionID, name, err) + return nil, &hostError{name: "Error", message: err.Error()} + } + + result := normalizeGraphQLResponse(resp.StatusCode, respBody) + if errorsJSON := graphQLErrors(result); errorsJSON != "" { + span.SetAttributes(attribute.String("codemode.graphql.errors", errorsJSON)) + observability.LogToolInvocationFailure(s.cfg.Logger, state.req.SessionID, name, fmt.Errorf("graphql errors: %s", errorsJSON)) + } + span.SetAttributes(attribute.Bool("codemode.op.success", resp.StatusCode < 400)) + if resp.StatusCode >= 400 { + observability.LogToolInvocationFailure(s.cfg.Logger, state.req.SessionID, name, fmt.Errorf("graphql http status %d", resp.StatusCode)) + } + return result, nil +} + +type graphQLRequest struct { + Query string `json:"query"` + OperationName string `json:"operationName"` + Variables json.RawMessage `json:"variables"` +} + +func mutationDeclinedResponse(reason string) json.RawMessage { + if reason == "" { + reason = "Mutation declined by operator" + } + body, _ := json.Marshal(map[string]any{ + "data": nil, + "errors": []map[string]string{{ + "message": "Mutation declined by operator: " + reason, + }}, + "declined": map[string]string{"reason": reason}, + }) + return body +} + +func normalizeGraphQLResponse(status int, body []byte) json.RawMessage { + var payload map[string]json.RawMessage + if err := json.Unmarshal(body, &payload); err == nil { + if status >= 400 { + if _, ok := payload["errors"]; ok { + out, _ := json.Marshal(payload) + 
return out + } + } + out, _ := json.Marshal(payload) + return out + } + msg := strings.TrimSpace(string(body)) + if msg == "" { + msg = http.StatusText(status) + } + out, _ := json.Marshal(map[string]any{ + "errors": []map[string]string{{"message": msg}}, + }) + return out +} + +func graphQLErrors(body json.RawMessage) string { + var payload struct { + Errors json.RawMessage `json:"errors"` + } + if err := json.Unmarshal(body, &payload); err != nil || len(payload.Errors) == 0 { + return "" + } + return string(payload.Errors) +} + +func readCapped(r io.Reader, capBytes int) ([]byte, error) { + data, err := io.ReadAll(io.LimitReader(r, int64(capBytes)+1)) + if err != nil { + return nil, err + } + if len(data) > capBytes { + return nil, fmt.Errorf("tools.* HTTP response body exceeded %d bytes", capBytes) + } + return data, nil +} diff --git a/router/internal/codemode/sandbox/preamble.go b/router/internal/codemode/sandbox/preamble.go new file mode 100644 index 0000000000..58b0bf57ef --- /dev/null +++ b/router/internal/codemode/sandbox/preamble.go @@ -0,0 +1,28 @@ +package sandbox + +import ( + _ "embed" + "strings" +) + +//go:embed sandbox_preamble.js +var preambleTemplate string + +const ( + spliceComment = "// Splice point: Execute.WrappedJS is already harness-wrapped and transpiled." 
+ agentMainSpliceID = "__AGENT_MAIN_SPLICE__" +) + +func buildPreamble(wrappedJS string) string { + return strings.Replace(preambleTemplate, agentMainSpliceID, wrappedJS, 1) +} + +func userCodeStartLine(program string) int { + lines := strings.Split(program, "\n") + for i, line := range lines { + if line == spliceComment { + return i + 2 + } + } + return 1 +} diff --git a/router/internal/codemode/sandbox/preamble_test.go b/router/internal/codemode/sandbox/preamble_test.go new file mode 100644 index 0000000000..e00fd9d562 --- /dev/null +++ b/router/internal/codemode/sandbox/preamble_test.go @@ -0,0 +1,94 @@ +package sandbox + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBuildPreambleGolden(t *testing.T) { + got := buildPreamble("async () => ({ ok: true })") + + want := `"use strict"; + +const tools = {}; +for (const name of __HOST_TOOL_NAMES) { + tools[name] = async (vars) => { + const __hostPayload = await __hostInvokeTool(name, vars); + const __hostResult = JSON.parse(__hostPayload); + if (__hostResult?.__codemodeHostError) { + const e = new Error(__hostResult.__codemodeHostError.message); + e.name = __hostResult.__codemodeHostError.name; + throw e; + } + return __hostResult; + }; +} +Object.freeze(tools); +globalThis.tools = tools; + +const __consoleErr = () => { + const e = new Error( + "console is not available in this sandbox. " + + "Include diagnostics in your return value, e.g. ` + "`return { result, debug: { ... } }`" + `." + ); + e.name = "ConsoleUnavailable"; + throw e; +}; +globalThis.console = new Proxy({}, { + get: () => __consoleErr, +}); + +Math.random = () => 0; +Date.now = () => 0; + +const __OrigDate = Date; +const __PinnedDate = function Date(...args) { + return args.length === 0 ? 
new __OrigDate(0) : new __OrigDate(...args); +}; +Object.setPrototypeOf(__PinnedDate, __OrigDate); +__PinnedDate.prototype = __OrigDate.prototype; +__PinnedDate.now = () => 0; +__PinnedDate.UTC = __OrigDate.UTC; +__PinnedDate.parse = __OrigDate.parse; +globalThis.Date = __PinnedDate; + +globalThis.notNull = (v, msg) => { + if (v == null) throw new Error(msg ?? "notNull: value was null/undefined"); + return v; +}; +globalThis.compact = (v) => { + if (Array.isArray(v)) return v.map(compact).filter((x) => x != null); + if (v && typeof v === "object") { + const out = {}; + for (const k in v) { + const c = compact(v[k]); + if (c != null) out[k] = c; + } + return out; + } + return v; +}; + +delete globalThis.eval; +delete globalThis.Function; +// Also remove indirect access via the Function constructor on the function prototype. +// (Function.prototype.constructor still exists per JS spec, but with eval/Function deleted +// it no longer resolves to a usable constructor.) + +// Splice point: Execute.WrappedJS is already harness-wrapped and transpiled. +const __agentMain = (async () => ({ ok: true })); +(async () => { + try { return { ok: true, result: await __agentMain() }; } + catch (err) { + return { ok: false, error: { name: err?.name ?? "Error", message: err?.message ?? String(err), stack: err?.stack ?? 
"", cause: err?.cause } }; + } +})() +` + assert.Equal(t, want, got) +} + +func TestBuildPreambleReportsUserCodeStartLine(t *testing.T) { + got := userCodeStartLine(buildPreamble("async () => 1")) + assert.Equal(t, 69, got) +} diff --git a/router/internal/codemode/sandbox/sandbox.go b/router/internal/codemode/sandbox/sandbox.go new file mode 100644 index 0000000000..a4d42f91f3 --- /dev/null +++ b/router/internal/codemode/sandbox/sandbox.go @@ -0,0 +1,168 @@ +package sandbox + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "time" + + "github.com/hashicorp/go-retryablehttp" + "github.com/wundergraph/cosmo/router/internal/codemode/storage" + "go.uber.org/zap" +) + +const ( + defaultRequestTimeout = 5 * time.Second + defaultMemoryLimitBytes = 16 << 20 + defaultMaxInputSizeBytes = 64 << 10 + defaultMaxOutputSizeBytes = 1 << 20 + defaultMaxResultBytes = 32 << 10 + defaultMaxToolInvocationsPerCall = 256 + defaultMaxResponseBodyBytes = 10 << 20 + defaultRetryAttempts = 3 + defaultRetryCeiling = 60 * time.Second + defaultMaxConcurrent = 4 +) + +type Sandbox struct { + cfg Config + sem chan struct{} + http *http.Client + allowList map[string]struct{} +} + +type Config struct { + RouterGraphQLEndpoint string + RequestTimeout time.Duration + MemoryLimitBytes int + MaxInputSizeBytes int + MaxOutputSizeBytes int + MaxResultBytes int + MaxToolInvocationsPerCall int + MaxResponseBodyBytes int + RetryAttempts int + RetryCeiling time.Duration + MaxConcurrent int + HeaderAllowList []string + StorageLookup func(ctx context.Context, sessionID string, name string) (storage.SessionOp, bool, error) + Logger *zap.Logger + Now func() time.Time + HTTPClient *http.Client +} + +type ExecuteRequest struct { + SessionID string + ToolNames []string + WrappedJS string + SourceMap []byte + RequestHeaders http.Header + ApprovalGate ApprovalGate +} + +type ExecuteResult struct { + OK bool + Result json.RawMessage + Error *ErrorEnvelope + Truncated bool + OutputSize int + HostCalls 
int +} + +type ErrorEnvelope struct { + Name string `json:"name"` + Message string `json:"message"` + Stack string `json:"stack"` + Cause *ErrorEnvelope `json:"cause,omitempty"` +} + +type ApprovalGate interface { + Decide(ctx context.Context, req ApprovalRequest) (ApprovalDecision, error) +} + +type ApprovalRequest struct { + Name string + Source string + Vars json.RawMessage +} + +type ApprovalDecision struct { + Approved bool + Reason string +} + +type approveAllGate struct{} + +var AutoApprove ApprovalGate = approveAllGate{} + +func (approveAllGate) Decide(context.Context, ApprovalRequest) (ApprovalDecision, error) { + return ApprovalDecision{Approved: true}, nil +} + +func New(cfg Config) (*Sandbox, error) { + cfg = withDefaults(cfg) + if cfg.MaxConcurrent <= 0 { + return nil, errors.New("sandbox max concurrent must be positive") + } + if cfg.StorageLookup == nil { + cfg.StorageLookup = func(context.Context, string, string) (storage.SessionOp, bool, error) { + return storage.SessionOp{}, false, nil + } + } + if cfg.Logger == nil { + cfg.Logger = zap.NewNop() + } + if cfg.Now == nil { + cfg.Now = time.Now + } + + client := cfg.HTTPClient + if client == nil { + retryClient := retryablehttp.NewClient() + retryClient.RetryMax = cfg.RetryAttempts + retryClient.RetryWaitMax = cfg.RetryCeiling + retryClient.Logger = nil + client = retryClient.StandardClient() + } + + return &Sandbox{ + cfg: cfg, + sem: make(chan struct{}, cfg.MaxConcurrent), + http: client, + allowList: headerAllowList(cfg.HeaderAllowList), + }, nil +} + +func withDefaults(cfg Config) Config { + if cfg.RequestTimeout <= 0 { + cfg.RequestTimeout = defaultRequestTimeout + } + if cfg.MemoryLimitBytes <= 0 { + cfg.MemoryLimitBytes = defaultMemoryLimitBytes + } + if cfg.MaxInputSizeBytes <= 0 { + cfg.MaxInputSizeBytes = defaultMaxInputSizeBytes + } + if cfg.MaxOutputSizeBytes <= 0 { + cfg.MaxOutputSizeBytes = defaultMaxOutputSizeBytes + } + if cfg.MaxResultBytes <= 0 { + cfg.MaxResultBytes = 
defaultMaxResultBytes + } + if cfg.MaxToolInvocationsPerCall <= 0 { + cfg.MaxToolInvocationsPerCall = defaultMaxToolInvocationsPerCall + } + if cfg.MaxResponseBodyBytes <= 0 { + cfg.MaxResponseBodyBytes = defaultMaxResponseBodyBytes + } + if cfg.RetryAttempts <= 0 { + cfg.RetryAttempts = defaultRetryAttempts + } + if cfg.RetryCeiling <= 0 { + cfg.RetryCeiling = defaultRetryCeiling + } + if cfg.MaxConcurrent <= 0 { + cfg.MaxConcurrent = defaultMaxConcurrent + } + return cfg +} diff --git a/router/internal/codemode/sandbox/sandbox_preamble.js b/router/internal/codemode/sandbox/sandbox_preamble.js new file mode 100644 index 0000000000..32ee04e1a4 --- /dev/null +++ b/router/internal/codemode/sandbox/sandbox_preamble.js @@ -0,0 +1,75 @@ +"use strict"; + +const tools = {}; +for (const name of __HOST_TOOL_NAMES) { + tools[name] = async (vars) => { + const __hostPayload = await __hostInvokeTool(name, vars); + const __hostResult = JSON.parse(__hostPayload); + if (__hostResult?.__codemodeHostError) { + const e = new Error(__hostResult.__codemodeHostError.message); + e.name = __hostResult.__codemodeHostError.name; + throw e; + } + return __hostResult; + }; +} +Object.freeze(tools); +globalThis.tools = tools; + +const __consoleErr = () => { + const e = new Error( + "console is not available in this sandbox. " + + "Include diagnostics in your return value, e.g. `return { result, debug: { ... } }`." + ); + e.name = "ConsoleUnavailable"; + throw e; +}; +globalThis.console = new Proxy({}, { + get: () => __consoleErr, +}); + +Math.random = () => 0; +Date.now = () => 0; + +const __OrigDate = Date; +const __PinnedDate = function Date(...args) { + return args.length === 0 ? 
new __OrigDate(0) : new __OrigDate(...args); +}; +Object.setPrototypeOf(__PinnedDate, __OrigDate); +__PinnedDate.prototype = __OrigDate.prototype; +__PinnedDate.now = () => 0; +__PinnedDate.UTC = __OrigDate.UTC; +__PinnedDate.parse = __OrigDate.parse; +globalThis.Date = __PinnedDate; + +globalThis.notNull = (v, msg) => { + if (v == null) throw new Error(msg ?? "notNull: value was null/undefined"); + return v; +}; +globalThis.compact = (v) => { + if (Array.isArray(v)) return v.map(compact).filter((x) => x != null); + if (v && typeof v === "object") { + const out = {}; + for (const k in v) { + const c = compact(v[k]); + if (c != null) out[k] = c; + } + return out; + } + return v; +}; + +delete globalThis.eval; +delete globalThis.Function; +// Also remove indirect access via the Function constructor on the function prototype. +// (Function.prototype.constructor still exists per JS spec, but with eval/Function deleted +// it no longer resolves to a usable constructor.) + +// Splice point: Execute.WrappedJS is already harness-wrapped and transpiled. +const __agentMain = (__AGENT_MAIN_SPLICE__); +(async () => { + try { return { ok: true, result: await __agentMain() }; } + catch (err) { + return { ok: false, error: { name: err?.name ?? "Error", message: err?.message ?? String(err), stack: err?.stack ?? 
"", cause: err?.cause } }; + } +})() diff --git a/router/internal/codemode/sandbox/sandbox_test.go b/router/internal/codemode/sandbox/sandbox_test.go new file mode 100644 index 0000000000..db69e82679 --- /dev/null +++ b/router/internal/codemode/sandbox/sandbox_test.go @@ -0,0 +1,648 @@ +package sandbox + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/evanw/esbuild/pkg/api" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router/internal/codemode/storage" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/sdk/trace" + "go.opentelemetry.io/otel/sdk/trace/tracetest" +) + +type DeclinedGate struct { + reason string +} + +func (g DeclinedGate) Decide(context.Context, ApprovalRequest) (ApprovalDecision, error) { + return ApprovalDecision{Approved: false, Reason: g.reason}, nil +} + +type nameDeclinedGate struct { + name string + reason string +} + +func (g nameDeclinedGate) Decide(_ context.Context, req ApprovalRequest) (ApprovalDecision, error) { + if req.Name == g.name { + return ApprovalDecision{Approved: false, Reason: g.reason}, nil + } + return ApprovalDecision{Approved: true}, nil +} + +type lookup map[string]storage.SessionOp + +func (l lookup) get(_ context.Context, _ string, name string) (storage.SessionOp, bool, error) { + op, ok := l[name] + return op, ok, nil +} + +func clientFunc(fn roundTripFunc) *http.Client { + return &http.Client{Transport: fn} +} + +func jsonResponse(status int, body string) *http.Response { + return &http.Response{ + StatusCode: status, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: ioNopCloser{bytes.NewBufferString(body)}, + } +} + +func newTestSandbox(t *testing.T, endpoint string, ops lookup, opts func(*Config)) *Sandbox { + t.Helper() + + cfg := Config{ + RouterGraphQLEndpoint: endpoint, + StorageLookup: ops.get, + RequestTimeout: 30 * 
time.Second, + RetryAttempts: 0, + } + if opts != nil { + opts(&cfg) + } + s, err := New(cfg) + require.NoError(t, err) + return s +} + +func execute(t *testing.T, s *Sandbox, req ExecuteRequest) ExecuteResult { + t.Helper() + + got, err := s.Execute(context.Background(), req) + require.NoError(t, err) + return got +} + +func raw(s string) json.RawMessage { + return json.RawMessage(s) +} + +func TestExecuteHappyPathToolCall(t *testing.T) { + var gotBody map[string]any + client := clientFunc(func(r *http.Request) (*http.Response, error) { + require.Equal(t, http.MethodPost, r.Method) + require.NoError(t, json.NewDecoder(r.Body).Decode(&gotBody)) + return jsonResponse(http.StatusOK, `{"data":{"order":{"id":"o1"}}}`), nil + }) + + s := newTestSandbox(t, "http://router/graphql", lookup{ + "getOrder": {Name: "getOrder", Body: "query GetOrder($id: ID!) { order(id: $id) { id } }", Kind: storage.OperationKindQuery}, + }, func(cfg *Config) { cfg.HTTPClient = client }) + + got := execute(t, s, ExecuteRequest{ + SessionID: "s1", + ToolNames: []string{"getOrder"}, + WrappedJS: `async () => { + return await tools.getOrder({ id: "o1" }); +}`, + }) + + assert.Equal(t, "getOrder", gotBody["operationName"]) + assert.Equal(t, ExecuteResult{ + OK: true, + Result: raw(`{"data":{"order":{"id":"o1"}}}`), + HostCalls: 1, + }, ExecuteResult{OK: got.OK, Result: got.Result, HostCalls: got.HostCalls}) +} + +func TestExecuteGraphQLErrorsResolveVerbatimAndRecordSpan(t *testing.T) { + client := clientFunc(func(r *http.Request) (*http.Response, error) { + return jsonResponse(http.StatusOK, `{"data":null,"errors":[{"message":"x"}]}`), nil + }) + + exporter := tracetest.NewInMemoryExporter() + tp := trace.NewTracerProvider(trace.WithSyncer(exporter)) + old := otel.GetTracerProvider() + otel.SetTracerProvider(tp) + defer otel.SetTracerProvider(old) + + s := newTestSandbox(t, "http://router/graphql", lookup{ + "getBroken": {Name: "getBroken", Body: "query Broken { broken }", Kind: 
storage.OperationKindQuery}, + }, func(cfg *Config) { cfg.HTTPClient = client }) + + ctx, span := otel.Tracer("sandbox-test").Start(context.Background(), "parent") + got, err := s.Execute(ctx, ExecuteRequest{ + SessionID: "s1", + ToolNames: []string{"getBroken"}, + WrappedJS: `async () => await tools.getBroken()`, + }) + span.End() + require.NoError(t, err) + + assert.Equal(t, ExecuteResult{ + OK: true, + Result: raw(`{"data":null,"errors":[{"message":"x"}]}`), + HostCalls: 1, + }, ExecuteResult{OK: got.OK, Result: got.Result, HostCalls: got.HostCalls}) + spans := exporter.GetSpans() + require.NotEmpty(t, spans) + var found bool + for _, sp := range spans { + for _, attr := range sp.Attributes { + if string(attr.Key) == "codemode.graphql.errors" && strings.Contains(attr.Value.AsString(), `"message":"x"`) { + found = true + } + } + } + assert.Equal(t, true, found) +} + +func TestExecuteHTTP500CanBeReturnedOrThrownByAgent(t *testing.T) { + client := clientFunc(func(r *http.Request) (*http.Response, error) { + return jsonResponse(http.StatusInternalServerError, `{"errors":[{"message":"upstream failed"}]}`), nil + }) + + s := newTestSandbox(t, "http://router/graphql", lookup{ + "getBroken": {Name: "getBroken", Body: "query Broken { broken }", Kind: storage.OperationKindQuery}, + }, func(cfg *Config) { cfg.HTTPClient = client }) + + returned := execute(t, s, ExecuteRequest{ + SessionID: "s1", + ToolNames: []string{"getBroken"}, + WrappedJS: `async () => await tools.getBroken()`, + }) + assert.Equal(t, ExecuteResult{ + OK: true, + Result: raw(`{"errors":[{"message":"upstream failed"}]}`), + HostCalls: 1, + }, ExecuteResult{OK: returned.OK, Result: returned.Result, HostCalls: returned.HostCalls}) + + thrown := execute(t, s, ExecuteRequest{ + SessionID: "s1", + ToolNames: []string{"getBroken"}, + WrappedJS: `async () => { + const r = await tools.getBroken(); + if (r.errors?.length) throw new Error(r.errors[0].message); + return r; +}`, + }) + assert.Equal(t, false, 
thrown.OK) + require.NotNil(t, thrown.Error) + assert.Equal(t, "Error", thrown.Error.Name) + assert.Equal(t, "upstream failed", thrown.Error.Message) + assert.Equal(t, 1, thrown.HostCalls) +} + +func TestExecuteConsoleUnavailable(t *testing.T) { + s := newTestSandbox(t, "", lookup{}, nil) + + got := execute(t, s, ExecuteRequest{WrappedJS: `async () => { console.log("x"); }`}) + + assert.Equal(t, false, got.OK) + require.NotNil(t, got.Error) + assert.Equal(t, ErrorEnvelope{ + Name: "ConsoleUnavailable", + Message: "console is not available in this sandbox. Include diagnostics in your return value, e.g. `return { result, debug: { ... } }`.", + Stack: got.Error.Stack, + }, *got.Error) +} + +func TestExecuteEvalAndFunctionRemoved(t *testing.T) { + s := newTestSandbox(t, "", lookup{}, nil) + + tests := []struct { + name string + wrappedJS string + want json.RawMessage + }{ + { + name: "typeof eval", + wrappedJS: `async () => { return typeof eval; }`, + want: raw(`"undefined"`), + }, + { + name: "typeof Function", + wrappedJS: `async () => { return typeof Function; }`, + want: raw(`"undefined"`), + }, + { + name: "indirect eval", + wrappedJS: `async () => { try { (0, eval)("1+1"); return "ok"; } catch (e) { return e.name + ":" + e.message; } }`, + want: raw(`"ReferenceError:eval is not defined"`), + }, + { + name: "new Function", + wrappedJS: `async () => { try { new Function("return 1"); return "ok"; } catch (e) { return e.name + ":" + e.message; } }`, + want: raw(`"ReferenceError:Function is not defined"`), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := execute(t, s, ExecuteRequest{WrappedJS: tt.wrappedJS}) + + assert.Equal(t, ExecuteResult{OK: true, Result: tt.want}, ExecuteResult{OK: got.OK, Result: got.Result}) + }) + } +} + +func TestExecuteDeterministicDateAndRandom(t *testing.T) { + s := newTestSandbox(t, "", lookup{}, nil) + + got := execute(t, s, ExecuteRequest{WrappedJS: `async () => ({ + random: Math.random(), + now: 
Date.now(), + epoch: new Date().getTime(), + parsed: new Date(123).getTime() +})`}) + + assert.Equal(t, ExecuteResult{ + OK: true, + Result: raw(`{"random":0,"now":0,"epoch":0,"parsed":123}`), + }, ExecuteResult{OK: got.OK, Result: got.Result}) +} + +func TestExecuteAllowsConfiguredHostCallCapAndThrowsOnNextCall(t *testing.T) { + var calls atomic.Int32 + client := clientFunc(func(r *http.Request) (*http.Response, error) { + calls.Add(1) + return jsonResponse(http.StatusOK, `{"data":{"ok":true}}`), nil + }) + s := newTestSandbox(t, "http://router/graphql", lookup{ + "foo": {Name: "foo", Body: "query Foo { foo }", Kind: storage.OperationKindQuery}, + }, func(cfg *Config) { cfg.HTTPClient = client }) + + withinCap := execute(t, s, ExecuteRequest{ + SessionID: "s1", + ToolNames: []string{"foo"}, + WrappedJS: `async () => { + for (let i = 0; i < 256; i++) await tools.foo({}); + return "ok"; +}`, + }) + + assert.Equal(t, ExecuteResult{ + OK: true, + Result: raw(`"ok"`), + HostCalls: 256, + }, ExecuteResult{OK: withinCap.OK, Result: withinCap.Result, HostCalls: withinCap.HostCalls}) + assert.Equal(t, int32(256), calls.Load()) + + got := execute(t, s, ExecuteRequest{ + SessionID: "s1", + ToolNames: []string{"foo"}, + WrappedJS: `async () => { + for (let i = 0; i < 257; i++) await tools.foo({}); + return null; +}`, + }) + + assert.Equal(t, false, got.OK) + require.NotNil(t, got.Error) + assert.Equal(t, "HostCallLimitExceeded", got.Error.Name) + assert.Equal(t, "tools.* invocation cap of 256 exceeded; batch independent calls with Promise.all.", got.Error.Message) + assert.Equal(t, 257, got.HostCalls) + assert.Equal(t, int32(512), calls.Load()) +} + +func TestExecutePromiseAllToolCallsRunInParallel(t *testing.T) { + var calls atomic.Int32 + client := clientFunc(func(r *http.Request) (*http.Response, error) { + calls.Add(1) + return jsonResponse(http.StatusOK, `{"data":{"ok":true}}`), nil + }) + s := newTestSandbox(t, "http://router/graphql", lookup{ + "ping": {Name: "ping", 
Body: "query Ping { ping }", Kind: storage.OperationKindQuery}, + }, func(cfg *Config) { cfg.HTTPClient = client }) + + got := execute(t, s, ExecuteRequest{ + SessionID: "s1", + ToolNames: []string{"ping"}, + WrappedJS: `async () => Promise.all([tools.ping(), tools.ping(), tools.ping(), tools.ping()])`, + }) + + assert.Equal(t, true, got.OK) + assert.Equal(t, 4, got.HostCalls) + assert.Equal(t, int32(4), calls.Load()) +} + +func TestExecuteAcceptsTopLevelAwaitStringAsHarnessDeviation(t *testing.T) { + s := newTestSandbox(t, "", lookup{}, nil) + + got := execute(t, s, ExecuteRequest{WrappedJS: `async () => await Promise.resolve(1)`}) + + assert.Equal(t, ExecuteResult{OK: true, Result: raw(`1`)}, ExecuteResult{OK: got.OK, Result: got.Result}) +} + +func TestExecuteWallClockTimeout(t *testing.T) { + s := newTestSandbox(t, "", lookup{}, func(cfg *Config) { + cfg.RequestTimeout = 25 * time.Millisecond + }) + + got := execute(t, s, ExecuteRequest{WrappedJS: `async () => await new Promise(() => {})`}) + + assert.Equal(t, false, got.OK) + require.NotNil(t, got.Error) + assert.Equal(t, "Timeout", got.Error.Name) +} + +func TestExecuteMemoryLimit(t *testing.T) { + s := newTestSandbox(t, "", lookup{}, func(cfg *Config) { + cfg.MemoryLimitBytes = 2 << 20 + }) + + got := execute(t, s, ExecuteRequest{WrappedJS: `async () => { + const xs = []; + for (let i = 0; i < 1000000; i++) xs.push("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"); + return xs.length; +}`}) + + assert.Equal(t, false, got.OK) + require.NotNil(t, got.Error) + assert.Equal(t, "MemoryLimit", got.Error.Name) +} + +func TestExecuteNotSerializableResult(t *testing.T) { + s := newTestSandbox(t, "", lookup{}, nil) + + got := execute(t, s, ExecuteRequest{WrappedJS: `async () => ({ x: () => 1 })`}) + + assert.Equal(t, false, got.OK) + assert.Equal(t, json.RawMessage(nil), got.Result) + require.NotNil(t, got.Error) + assert.Equal(t, ErrorEnvelope{ + Name: "NotSerializable", + Message: "return value contains non-JSON-serializable 
values at $.x", + Stack: "", + }, *got.Error) +} + +func TestExecuteNotSerializableProducesErrorEnvelope(t *testing.T) { + s := newTestSandbox(t, "", lookup{}, nil) + + got := execute(t, s, ExecuteRequest{WrappedJS: `async () => { return { x: () => 1, y: 5n, cycle: (() => { const o = {}; o.self = o; return o; })() }; }`}) + + assert.Equal(t, false, got.OK) + assert.Equal(t, json.RawMessage(nil), got.Result) + require.NotNil(t, got.Error) + assert.Equal(t, ErrorEnvelope{ + Name: "NotSerializable", + Message: "return value contains non-JSON-serializable values at $.x, $.y, $.cycle.self", + Stack: "", + }, *got.Error) +} + +func TestExecuteOutputTooLarge(t *testing.T) { + s := newTestSandbox(t, "", lookup{}, func(cfg *Config) { + cfg.MaxOutputSizeBytes = 10 + }) + + got := execute(t, s, ExecuteRequest{WrappedJS: `async () => "this is too large"`}) + + assert.Equal(t, false, got.OK) + require.NotNil(t, got.Error) + assert.Equal(t, "OutputTooLarge", got.Error.Name) + assert.Contains(t, got.Error.Message, "encoded result size") +} + +func TestExecuteErrorCauseChain(t *testing.T) { + s := newTestSandbox(t, "", lookup{}, nil) + + got := execute(t, s, ExecuteRequest{WrappedJS: `async () => { + throw new Error("a", { cause: new Error("b", { cause: new Error("c") }) }); +}`}) + + assert.Equal(t, false, got.OK) + require.NotNil(t, got.Error) + assert.Equal(t, "a", got.Error.Message) + require.NotNil(t, got.Error.Cause) + assert.Equal(t, "b", got.Error.Cause.Message) + require.NotNil(t, got.Error.Cause.Cause) + assert.Equal(t, "c", got.Error.Cause.Cause.Message) + assert.Nil(t, got.Error.Cause.Cause.Cause) +} + +func TestExecuteErrorCauseChainTruncatesAfterDepthFive(t *testing.T) { + s := newTestSandbox(t, "", lookup{}, nil) + + got := execute(t, s, ExecuteRequest{WrappedJS: `async () => { + let err = new Error("7"); + for (let i = 6; i >= 1; i--) err = new Error(String(i), { cause: err }); + throw err; +}`}) + + assert.Equal(t, false, got.OK) + require.NotNil(t, got.Error) + 
cause := got.Error + for range 5 { + require.NotNil(t, cause.Cause) + cause = cause.Cause + } + assert.Equal(t, "TruncatedCause", cause.Name) + assert.Equal(t, "cause chain exceeded depth 5", cause.Message) +} + +func TestExecuteSourceMapRewrite(t *testing.T) { + ts := "async () => {\n const x: number = 1;\n throw new Error(\"boom\");\n}" + transformed := api.Transform(ts, api.TransformOptions{ + Loader: api.LoaderTS, + Sourcemap: api.SourceMapExternal, + Sourcefile: "agent.ts", + }) + require.Empty(t, transformed.Errors) + js := strings.TrimSpace(string(transformed.Code)) + js = strings.TrimSuffix(js, ";") + s := newTestSandbox(t, "", lookup{}, nil) + + got := execute(t, s, ExecuteRequest{WrappedJS: js, SourceMap: []byte(transformed.Map)}) + + assert.Equal(t, false, got.OK) + require.NotNil(t, got.Error) + assert.Contains(t, got.Error.Stack, "agent.ts:3:") +} + +func TestExecuteMutationApprovalDeclined(t *testing.T) { + var calls atomic.Int32 + client := clientFunc(func(r *http.Request) (*http.Response, error) { + calls.Add(1) + return jsonResponse(http.StatusOK, `{"data":{"ok":true}}`), nil + }) + s := newTestSandbox(t, "http://router/graphql", lookup{ + "deleteOrder": {Name: "deleteOrder", Body: "mutation DeleteOrder { deleteOrder }", Kind: storage.OperationKindMutation}, + }, func(cfg *Config) { cfg.HTTPClient = client }) + + got := execute(t, s, ExecuteRequest{ + SessionID: "s1", + ToolNames: []string{"deleteOrder"}, + ApprovalGate: DeclinedGate{reason: "no thanks"}, + WrappedJS: `async () => await tools.deleteOrder({ id: "o1" })`, + RequestHeaders: http.Header{}, + }) + + assert.Equal(t, int32(0), calls.Load()) + assert.Equal(t, ExecuteResult{ + OK: true, + Result: raw(`{"data":null,"declined":{"reason":"no thanks"},"errors":[{"message":"Mutation declined by operator: no thanks"}]}`), + HostCalls: 1, + }, ExecuteResult{OK: got.OK, Result: got.Result, HostCalls: got.HostCalls}) +} + +func TestExecuteSpecificMutationApprovalDeclinedReturnsStructuredValue(t 
*testing.T) { + var calls atomic.Int32 + client := clientFunc(func(r *http.Request) (*http.Response, error) { + calls.Add(1) + return jsonResponse(http.StatusOK, `{"data":{"ok":true}}`), nil + }) + s := newTestSandbox(t, "http://router/graphql", lookup{ + "deleteOrders": {Name: "deleteOrders", Body: "mutation DeleteOrders($id: ID!) { deleteOrders(id: $id) }", Kind: storage.OperationKindMutation}, + }, func(cfg *Config) { cfg.HTTPClient = client }) + + got := execute(t, s, ExecuteRequest{ + SessionID: "s1", + ToolNames: []string{"deleteOrders"}, + ApprovalGate: nameDeclinedGate{name: "deleteOrders", reason: "policy forbids"}, + WrappedJS: `async () => { const r = await tools.deleteOrders({id:"x"}); return r; }`, + RequestHeaders: http.Header{}, + }) + + assert.Equal(t, int32(0), calls.Load()) + assert.Equal(t, ExecuteResult{ + OK: true, + Result: raw(`{"data":null,"declined":{"reason":"policy forbids"},"errors":[{"message":"Mutation declined by operator: policy forbids"}]}`), + HostCalls: 1, + }, ExecuteResult{OK: got.OK, Result: got.Result, HostCalls: got.HostCalls}) +} + +func TestExecuteHeaderAllowList(t *testing.T) { + seen := make(chan http.Header, 1) + client := clientFunc(func(r *http.Request) (*http.Response, error) { + seen <- r.Header.Clone() + return jsonResponse(http.StatusOK, `{"data":{"ok":true}}`), nil + }) + s := newTestSandbox(t, "http://router/graphql", lookup{ + "ping": {Name: "ping", Body: "query Ping { ping }", Kind: storage.OperationKindQuery}, + }, func(cfg *Config) { + cfg.HeaderAllowList = []string{"Authorization", "X-Trace"} + cfg.HTTPClient = client + }) + + got := execute(t, s, ExecuteRequest{ + SessionID: "s1", + ToolNames: []string{"ping"}, + WrappedJS: `async () => await tools.ping()`, + RequestHeaders: http.Header{ + "Authorization": []string{"Bearer token"}, + "X-Trace": []string{"trace-1"}, + "X-Skip": []string{"skip"}, + "Connection": []string{"keep-alive"}, + }, + }) + + headers := <-seen + assert.Equal(t, true, got.OK) + 
assert.Equal(t, "Bearer token", headers.Get("Authorization")) + assert.Equal(t, "trace-1", headers.Get("X-Trace")) + assert.Equal(t, "", headers.Get("X-Skip")) + assert.Equal(t, "", headers.Get("Connection")) + assert.Equal(t, "application/json", headers.Get("Content-Type")) +} + +func TestExecuteSemaphoreBoundsConcurrency(t *testing.T) { + var active atomic.Int32 + var maxActive atomic.Int32 + started := make(chan struct{}, 5) + release := make(chan struct{}) + client := &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + now := active.Add(1) + for { + max := maxActive.Load() + if now <= max || maxActive.CompareAndSwap(max, now) { + break + } + } + started <- struct{}{} + <-release + active.Add(-1) + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: ioNopCloser{bytes.NewBufferString(`{"data":{"ok":true}}`)}, + }, nil + })} + s := newTestSandbox(t, "http://router/graphql", lookup{ + "ping": {Name: "ping", Body: "query Ping { ping }", Kind: storage.OperationKindQuery}, + }, func(cfg *Config) { + cfg.MaxConcurrent = 4 + cfg.HTTPClient = client + }) + + var wg sync.WaitGroup + for range 5 { + wg.Add(1) + go func() { + defer wg.Done() + _, err := s.Execute(context.Background(), ExecuteRequest{ + SessionID: "s1", + ToolNames: []string{"ping"}, + WrappedJS: `async () => await tools.ping()`, + }) + assert.NoError(t, err) + }() + } + + for range 4 { + <-started + } + assert.Equal(t, int32(4), maxActive.Load()) + assert.Equal(t, int32(4), active.Load()) + select { + case <-started: + t.Fatal("fifth Execute entered before a semaphore slot was released") + default: + } + close(release) + wg.Wait() + assert.Equal(t, int32(4), maxActive.Load()) +} + +func TestExecuteFrozenToolsAssignmentThrowsInStrictMode(t *testing.T) { + s := newTestSandbox(t, "", lookup{ + "foo": {Name: "foo", Body: "query Foo { foo }", Kind: storage.OperationKindQuery}, + }, nil) + + got := execute(t, s, ExecuteRequest{ToolNames: 
[]string{"foo"}, WrappedJS: `async () => { + tools.foo = () => null; + return tools.foo === null; +}`}) + + assert.Equal(t, false, got.OK) + require.NotNil(t, got.Error) + assert.Equal(t, "TypeError", got.Error.Name) +} + +func TestExecuteUnknownToolName(t *testing.T) { + s := newTestSandbox(t, "", lookup{}, nil) + + got := execute(t, s, ExecuteRequest{WrappedJS: `async () => await tools.nope()`}) + + assert.Equal(t, false, got.OK) + require.NotNil(t, got.Error) + assert.Equal(t, "TypeError", got.Error.Name) + // qjs reports native missing-method calls in this form for plain objects. + assert.Equal(t, "tools.nope is not a function", got.Error.Message) +} + +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { + return f(r) +} + +type ioNopCloser struct { + *bytes.Buffer +} + +func (c ioNopCloser) Close() error { + return nil +} diff --git a/router/internal/codemode/sandbox/semaphore.go b/router/internal/codemode/sandbox/semaphore.go new file mode 100644 index 0000000000..3677255c7b --- /dev/null +++ b/router/internal/codemode/sandbox/semaphore.go @@ -0,0 +1,16 @@ +package sandbox + +import "context" + +func (s *Sandbox) acquire(ctx context.Context) error { + select { + case s.sem <- struct{}{}: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +func (s *Sandbox) release() { + <-s.sem +} diff --git a/router/internal/codemode/sandbox/validation.go b/router/internal/codemode/sandbox/validation.go new file mode 100644 index 0000000000..8353bc8790 --- /dev/null +++ b/router/internal/codemode/sandbox/validation.go @@ -0,0 +1,98 @@ +package sandbox + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/fastschema/qjs" +) + +const validationHelpers = ` +globalThis.__codemodeNormalizeError = (err, depth = 0) => { + if (!err) return null; + if (depth >= 5) return { name: "TruncatedCause", message: "cause chain exceeded depth 5", stack: "" }; + return { + name: 
err?.name ?? "Error", + message: err?.message ?? String(err), + stack: err?.stack ?? "", + cause: err?.cause ? __codemodeNormalizeError(err.cause, depth + 1) : null, + }; +}; +globalThis.__codemodeNormalizeErrorJSON = (err) => JSON.stringify(__codemodeNormalizeError(err)); + +globalThis.__codemodeValidateResult = (value) => { + const bad = []; + const seen = new WeakSet(); + const keyPath = (base, key) => { + if (typeof key === "number") return base + "[" + key + "]"; + return /^[A-Za-z_$][A-Za-z0-9_$]*$/.test(key) ? base + "." + key : base + "[" + JSON.stringify(key) + "]"; + }; + const walk = (v, path) => { + const t = typeof v; + if (t === "bigint" || t === "function" || t === "symbol" || t === "undefined") { + bad.push(path); + return; + } + if (v && t === "object") { + if (seen.has(v)) { + bad.push(path); + return; + } + seen.add(v); + if (Array.isArray(v)) { + for (let i = 0; i < v.length; i++) walk(v[i], keyPath(path, i)); + return; + } + for (const k of Object.keys(v)) walk(v[k], keyPath(path, k)); + } + }; + walk(value, "$"); + if (bad.length) return JSON.stringify({ serializable: false, paths: bad }); + try { + const json = JSON.stringify(value); + if (json === undefined) return JSON.stringify({ serializable: false, paths: ["$"] }); + return JSON.stringify({ serializable: true, json }); + } catch (err) { + return JSON.stringify({ serializable: false, paths: ["$"] }); + } +}; +` + +type validationOutcome struct { + Serializable bool `json:"serializable"` + JSON string `json:"json"` + Paths []string `json:"paths"` +} + +func installValidationHelpers(ctx *qjs.Context) error { + val, err := ctx.Eval("codemode_validation.js", qjs.Code(validationHelpers)) + _ = val + return err +} + +func validateResult(ctx *qjs.Context, result *qjs.Value, maxOutputBytes int) (json.RawMessage, *ErrorEnvelope, error) { + global := ctx.Global() + validator := global.GetPropertyStr("__codemodeValidateResult") + encoded, err := ctx.Invoke(validator, global, result) + if err != nil 
{ + return nil, nil, err + } + + var outcome validationOutcome + if err := json.Unmarshal([]byte(encoded.String()), &outcome); err != nil { + return nil, nil, err + } + if !outcome.Serializable { + message := "return value contains non-JSON-serializable values at " + strings.Join(outcome.Paths, ", ") + return nil, &ErrorEnvelope{Name: "NotSerializable", Message: message, Stack: ""}, nil + } + if len(outcome.JSON) > maxOutputBytes { + return nil, &ErrorEnvelope{ + Name: "OutputTooLarge", + Message: fmt.Sprintf("encoded result size %d bytes exceeds limit %d bytes", len(outcome.JSON), maxOutputBytes), + Stack: "", + }, nil + } + return json.RawMessage(outcome.JSON), nil, nil +} diff --git a/router/internal/codemode/server/approval.go b/router/internal/codemode/server/approval.go new file mode 100644 index 0000000000..81d937c38a --- /dev/null +++ b/router/internal/codemode/server/approval.go @@ -0,0 +1,195 @@ +package server + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "strings" + "unicode/utf8" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/wundergraph/cosmo/router/internal/codemode/observability" + "github.com/wundergraph/cosmo/router/internal/codemode/sandbox" + "github.com/wundergraph/graphql-go-tools/v2/pkg/astparser" + "github.com/wundergraph/graphql-go-tools/v2/pkg/astprinter" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + "go.uber.org/zap" +) + +const defaultMutationDeclinedReason = "Mutation declined by operator" + +// Elicitor is the testable subset of the MCP elicitation API used by mutation approval. 
+type Elicitor interface { + Elicit(ctx context.Context, params ElicitParams) (ElicitResponse, error) +} + +type ElicitParams struct { + Message string + RequestedSchema any +} + +type ElicitResponse struct { + Action string + FormData map[string]any +} + +type ElicitationGate struct { + elicitor Elicitor + logger *zap.Logger +} + +func NewElicitationGate(elicitor Elicitor, logger *zap.Logger) *ElicitationGate { + if logger == nil { + logger = zap.NewNop() + } + return &ElicitationGate{elicitor: elicitor, logger: logger} +} + +func (g *ElicitationGate) Decide(ctx context.Context, req sandbox.ApprovalRequest) (sandbox.ApprovalDecision, error) { + if g == nil || g.elicitor == nil { + decision := unsupportedElicitationDecision(errors.New("elicitor is not configured")) + recordMutationApproval(ctx, decision) + observability.LogElicitationOutcome(g.logger, SessionIDFromContext(ctx), decision.Approved, decision.Reason) + return decision, nil + } + + resp, err := g.elicitor.Elicit(ctx, ElicitParams{ + Message: mutationApprovalMessage(req), + RequestedSchema: mutationApprovalSchema(), + }) + if err != nil { + decision := unsupportedElicitationDecision(err) + recordMutationApproval(ctx, decision) + observability.LogElicitationOutcome(g.logger, SessionIDFromContext(ctx), decision.Approved, decision.Reason) + return decision, nil + } + + decision := decisionFromElicitation(resp) + recordMutationApproval(ctx, decision) + observability.LogElicitationOutcome(g.logger, SessionIDFromContext(ctx), decision.Approved, decision.Reason) + return decision, nil +} + +type MCPElicitor struct { + session *mcp.ServerSession +} + +func NewMCPElicitor(session *mcp.ServerSession) *MCPElicitor { + return &MCPElicitor{session: session} +} + +func (e *MCPElicitor) Elicit(ctx context.Context, params ElicitParams) (ElicitResponse, error) { + if e == nil || e.session == nil { + return ElicitResponse{}, errors.New("MCP server session is not available") + } + resp, err := e.session.Elicit(ctx, 
&mcp.ElicitParams{ + Message: params.Message, + RequestedSchema: params.RequestedSchema, + }) + if err != nil { + return ElicitResponse{}, err + } + if resp == nil { + return ElicitResponse{}, nil + } + return ElicitResponse{Action: resp.Action, FormData: resp.Content}, nil +} + +func decisionFromElicitation(resp ElicitResponse) sandbox.ApprovalDecision { + if resp.Action != "accept" || resp.FormData == nil { + return sandbox.ApprovalDecision{Approved: false, Reason: defaultMutationDeclinedReason} + } + if approved, ok := resp.FormData["approved"].(bool); ok && approved { + return sandbox.ApprovalDecision{Approved: true} + } + reason, _ := resp.FormData["reason"].(string) + return sandbox.ApprovalDecision{Approved: false, Reason: sanitizeMutationApprovalReason(reason)} +} + +func unsupportedElicitationDecision(err error) sandbox.ApprovalDecision { + return sandbox.ApprovalDecision{ + Approved: false, + Reason: fmt.Sprintf("mutation approval is required but the MCP client does not support elicitation: %s", err), + } +} + +func mutationApprovalSchema() map[string]any { + return map[string]any{ + "type": "object", + "required": []string{"approved"}, + "properties": map[string]any{ + "approved": map[string]any{"type": "boolean"}, + "reason": map[string]any{"type": "string", "maxLength": 500}, + }, + } +} + +func mutationApprovalMessage(req sandbox.ApprovalRequest) string { + return fmt.Sprintf( + "Approve GraphQL mutation %q?\n\nGraphQL mutation:\n\n%s\n\nVariables:\n\n%s", + req.Name, + prettyMutationSource(req.Source), + prettyMutationVariables(req.Vars), + ) +} + +// prettyMutationSource reformats a GraphQL operation body with two-space indentation. +// On any parse failure the original source is returned verbatim — operator-visible +// readability is best-effort, and we never want to swallow what they actually approve. 
+func prettyMutationSource(source string) string { + doc, report := astparser.ParseGraphqlDocumentString(source) + if report.HasErrors() { + return source + } + pretty, err := astprinter.PrintStringIndent(&doc, " ") + if err != nil { + return source + } + return pretty +} + +func prettyMutationVariables(vars json.RawMessage) string { + if len(vars) == 0 { + return "{}" + } + var decoded any + if err := json.Unmarshal(vars, &decoded); err != nil { + return string(vars) + } + pretty, err := json.MarshalIndent(decoded, "", " ") + if err != nil { + return string(vars) + } + return string(pretty) +} + +func sanitizeMutationApprovalReason(reason string) string { + var b strings.Builder + for len(reason) > 0 { + r, size := utf8.DecodeRuneInString(reason) + if r == utf8.RuneError && size == 1 { + reason = reason[size:] + continue + } + if r < 0x20 { + reason = reason[size:] + continue + } + if b.Len()+size > 500 { + break + } + b.WriteString(reason[:size]) + reason = reason[size:] + } + return b.String() +} + +func recordMutationApproval(ctx context.Context, decision sandbox.ApprovalDecision) { + span := trace.SpanFromContext(ctx) + span.SetAttributes( + attribute.Bool("code_mode.mutation.approved", decision.Approved), + attribute.String("code_mode.mutation.reason", decision.Reason), + ) +} diff --git a/router/internal/codemode/server/approval_test.go b/router/internal/codemode/server/approval_test.go new file mode 100644 index 0000000000..e15bb3a76e --- /dev/null +++ b/router/internal/codemode/server/approval_test.go @@ -0,0 +1,150 @@ +package server + +import ( + "context" + "encoding/json" + "errors" + "strings" + "testing" + "unicode/utf8" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router/internal/codemode/sandbox" + "go.uber.org/zap" +) + +type fakeElicitor struct { + response ElicitResponse + err error + params ElicitParams +} + +func (f *fakeElicitor) Elicit(ctx context.Context, params 
ElicitParams) (ElicitResponse, error) { + f.params = params + if f.err != nil { + return ElicitResponse{}, f.err + } + return f.response, nil +} + +func TestElicitationGateAcceptApprovedTrue(t *testing.T) { + elicitor := &fakeElicitor{ + response: ElicitResponse{Action: "accept", FormData: map[string]any{"approved": true}}, + } + gate := NewElicitationGate(elicitor, zap.NewNop()) + + got, err := gate.Decide(context.Background(), sandbox.ApprovalRequest{ + Name: "deleteOrders", + Source: "mutation DeleteOrders { deleteOrders }", + Vars: json.RawMessage(`{"id":"x"}`), + }) + + require.NoError(t, err) + assert.Equal(t, sandbox.ApprovalDecision{Approved: true, Reason: ""}, got) + assert.Equal(t, map[string]any{ + "type": "object", + "required": []string{"approved"}, + "properties": map[string]any{ + "approved": map[string]any{"type": "boolean"}, + "reason": map[string]any{"type": "string", "maxLength": 500}, + }, + }, elicitor.params.RequestedSchema) + assert.Equal(t, "Approve GraphQL mutation \"deleteOrders\"?\n\nGraphQL mutation:\n\nmutation DeleteOrders {\n deleteOrders\n}\n\nVariables:\n\n{\n \"id\": \"x\"\n}", elicitor.params.Message) +} + +func TestElicitationGateAcceptApprovedFalseUsesReason(t *testing.T) { + elicitor := &fakeElicitor{ + response: ElicitResponse{Action: "accept", FormData: map[string]any{"approved": false, "reason": "no thanks"}}, + } + gate := NewElicitationGate(elicitor, zap.NewNop()) + + got, err := gate.Decide(context.Background(), sandbox.ApprovalRequest{Name: "deleteOrders"}) + + require.NoError(t, err) + assert.Equal(t, sandbox.ApprovalDecision{Approved: false, Reason: "no thanks"}, got) +} + +func TestElicitationGateAcceptApprovedFalseStripsControlCharacters(t *testing.T) { + elicitor := &fakeElicitor{ + response: ElicitResponse{Action: "accept", FormData: map[string]any{"approved": false, "reason": "no\x00 \x01thanks\x1f"}}, + } + gate := NewElicitationGate(elicitor, zap.NewNop()) + + got, err := gate.Decide(context.Background(), 
sandbox.ApprovalRequest{Name: "deleteOrders"}) + + require.NoError(t, err) + assert.Equal(t, sandbox.ApprovalDecision{Approved: false, Reason: "no thanks"}, got) +} + +func TestElicitationGateAcceptApprovedFalseTruncatesReasonUTF8Safely(t *testing.T) { + elicitor := &fakeElicitor{ + response: ElicitResponse{Action: "accept", FormData: map[string]any{"approved": false, "reason": strings.Repeat("é", 300)}}, + } + gate := NewElicitationGate(elicitor, zap.NewNop()) + + got, err := gate.Decide(context.Background(), sandbox.ApprovalRequest{Name: "deleteOrders"}) + + require.NoError(t, err) + assert.Equal(t, sandbox.ApprovalDecision{Approved: false, Reason: strings.Repeat("é", 250)}, got) + assert.Equal(t, 500, len(got.Reason)) + assert.Equal(t, true, utf8.ValidString(got.Reason)) +} + +func TestElicitationGateDeclineAction(t *testing.T) { + elicitor := &fakeElicitor{response: ElicitResponse{Action: "decline"}} + gate := NewElicitationGate(elicitor, zap.NewNop()) + + got, err := gate.Decide(context.Background(), sandbox.ApprovalRequest{Name: "deleteOrders"}) + + require.NoError(t, err) + assert.Equal(t, sandbox.ApprovalDecision{Approved: false, Reason: "Mutation declined by operator"}, got) +} + +func TestElicitationGateCancelAction(t *testing.T) { + elicitor := &fakeElicitor{response: ElicitResponse{Action: "cancel"}} + gate := NewElicitationGate(elicitor, zap.NewNop()) + + got, err := gate.Decide(context.Background(), sandbox.ApprovalRequest{Name: "deleteOrders"}) + + require.NoError(t, err) + assert.Equal(t, sandbox.ApprovalDecision{Approved: false, Reason: "Mutation declined by operator"}, got) +} + +func TestElicitationGateUnsupportedElicitationErrorDeclines(t *testing.T) { + elicitor := &fakeElicitor{err: errors.New("elicitation not supported")} + gate := NewElicitationGate(elicitor, zap.NewNop()) + + got, err := gate.Decide(context.Background(), sandbox.ApprovalRequest{Name: "deleteOrders"}) + + require.NoError(t, err) + assert.Equal(t, sandbox.ApprovalDecision{ + 
Approved: false, + Reason: "mutation approval is required but the MCP client does not support elicitation: elicitation not supported", + }, got) +} + +func TestElicitationGateContextCanceledErrorDeclines(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + elicitor := &fakeElicitor{err: ctx.Err()} + gate := NewElicitationGate(elicitor, zap.NewNop()) + + got, err := gate.Decide(ctx, sandbox.ApprovalRequest{Name: "deleteOrders"}) + + require.NoError(t, err) + assert.Equal(t, sandbox.ApprovalDecision{ + Approved: false, + Reason: "mutation approval is required but the MCP client does not support elicitation: context canceled", + }, got) +} + +func TestElicitationGateAcceptWithoutFormDataDeclines(t *testing.T) { + elicitor := &fakeElicitor{response: ElicitResponse{Action: "accept"}} + gate := NewElicitationGate(elicitor, zap.NewNop()) + + got, err := gate.Decide(context.Background(), sandbox.ApprovalRequest{Name: "deleteOrders"}) + + require.NoError(t, err) + assert.Equal(t, sandbox.ApprovalDecision{Approved: false, Reason: "Mutation declined by operator"}, got) +} diff --git a/router/internal/codemode/server/execute_handler.go b/router/internal/codemode/server/execute_handler.go new file mode 100644 index 0000000000..70234b3525 --- /dev/null +++ b/router/internal/codemode/server/execute_handler.go @@ -0,0 +1,101 @@ +package server + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "strings" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/wundergraph/cosmo/router/internal/codemode/harness" + "github.com/wundergraph/cosmo/router/internal/codemode/observability" + "github.com/wundergraph/cosmo/router/internal/codemode/sandbox" +) + +type executeAPIInput struct { + Source string `json:"source"` +} + +func (s *Server) handleExecuteAPI(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + ctx = contextWithSessionFromExtra(ctx, req.GetExtra()) + + source, err := 
decodeExecuteSource(req) + if err != nil { + return toolErrorResult(err.Error()), nil + } + + if !s.namedOpsEnabled || s.sessionStateless { + return toolErrorResult(namedOpsDisabledMessage), nil + } + + sessionID := SessionIDFromContext(ctx) + if sessionID == "" { + return toolErrorResult(namedOpsDisabledMessage), nil + } + if s.storage == nil { + return toolErrorResult("code_mode_run_js: storage is not configured"), nil + } + if s.pipeline == nil { + return toolErrorResult("code_mode_run_js: pipeline failed: code mode execute pipeline is not configured"), nil + } + + names, err := s.storage.ListNames(ctx, sessionID) + if err != nil { + return toolErrorResult(fmt.Sprintf("code_mode_run_js: failed to list tools: %v", err)), nil + } + + executeTimeout := s.executeTimeout + if executeTimeout <= 0 { + executeTimeout = defaultExecuteTimeout + } + execCtx, cancel := context.WithTimeout(ctx, executeTimeout) + defer cancel() + + response, err := s.pipeline.Execute(execCtx, harness.PipelineRequest{ + SessionID: sessionID, + ToolNames: names, + Source: source, + RequestHeaders: requestHeaders(req), + ApprovalGate: s.approvalGateForRequest(req), + }) + if err != nil { + return toolErrorResult(fmt.Sprintf("code_mode_run_js: pipeline failed: %v", err)), nil + } + if response.Envelope.Error != nil && response.Envelope.Error.Name == "TranspileError" { + observability.LogTranspileFailure(s.logger, sessionID, response.Envelope.Error.Message) + } + return textResult(string(response.Encoded)), nil +} + +func decodeExecuteSource(req *mcp.CallToolRequest) (string, error) { + var input executeAPIInput + if req != nil && req.Params != nil && len(req.Params.Arguments) > 0 { + if err := json.Unmarshal(req.Params.Arguments, &input); err != nil { + return "", errors.New("code_mode_run_js: source must be a non-empty string") + } + } + if strings.TrimSpace(input.Source) == "" { + return "", errors.New("code_mode_run_js: source must be a non-empty string") + } + return input.Source, nil +} + 
+func (s *Server) approvalGateForRequest(req *mcp.CallToolRequest) sandbox.ApprovalGate { + if s.approvalGate != nil { + return s.approvalGate + } + var session *mcp.ServerSession + if req != nil { + session = req.Session + } + return NewElicitationGate(NewMCPElicitor(session), s.logger) +} + +func requestHeaders(req *mcp.CallToolRequest) http.Header { + if req == nil || req.GetExtra() == nil { + return nil + } + return req.GetExtra().Header.Clone() +} diff --git a/router/internal/codemode/server/execute_handler_test.go b/router/internal/codemode/server/execute_handler_test.go new file mode 100644 index 0000000000..57cef16d6e --- /dev/null +++ b/router/internal/codemode/server/execute_handler_test.go @@ -0,0 +1,431 @@ +package server + +import ( + "context" + "encoding/json" + "net/http" + "sync" + "testing" + "time" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router/internal/codemode/harness" + "github.com/wundergraph/cosmo/router/internal/codemode/sandbox" + "github.com/wundergraph/cosmo/router/internal/codemode/storage" + "github.com/wundergraph/cosmo/router/internal/codemode/tsgen" + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" + "go.opentelemetry.io/otel/trace" + "go.uber.org/zap" +) + +func TestHandleExecuteValidatesSource(t *testing.T) { + srv := newExecuteTestServer(t, Config{ + CodeModeEnabled: true, + NamedOpsEnabled: true, + SessionStateless: false, + Pipeline: &recordingPipeline{}, + }, newExecuteTestStorage()) + + got, err := srv.handleExecute(context.Background(), executeToolRequest(t, "session-1", map[string]any{ + "source": "", + })) + + require.NoError(t, err) + assert.Equal(t, toolError("code_mode_run_js: source must be a non-empty string"), got) +} + +func TestHandleExecuteNamedOpsDisabled(t *testing.T) { + srv := newExecuteTestServer(t, Config{ + CodeModeEnabled: true, + NamedOpsEnabled: false, + SessionStateless: false, + 
Pipeline: &recordingPipeline{}, + }, newExecuteTestStorage()) + + got, err := srv.handleExecute(context.Background(), executeToolRequest(t, "session-1", map[string]any{ + "source": "async () => null", + })) + + require.NoError(t, err) + assert.Equal(t, toolError("named operations are disabled"), got) +} + +func TestHandleExecuteStatelessDisablesNamedOps(t *testing.T) { + srv := newExecuteTestServer(t, Config{ + CodeModeEnabled: true, + NamedOpsEnabled: true, + SessionStateless: true, + Pipeline: &recordingPipeline{}, + }, newExecuteTestStorage()) + + got, err := srv.handleExecute(context.Background(), executeToolRequest(t, "session-1", map[string]any{ + "source": "async () => null", + })) + + require.NoError(t, err) + assert.Equal(t, toolError("named operations are disabled"), got) +} + +func TestHandleExecuteStatefulHappyPathReturnsEncodedEnvelope(t *testing.T) { + store := newExecuteTestStorage() + store.ops["session-1"] = []storage.SessionOp{{ + Name: "someName", + Body: "query SomeName { orders { id total } }", + Kind: storage.OperationKindQuery, + }} + pipeline := &recordingPipeline{ + response: pipelineResponse(t, harness.ResultEnvelope{ + Result: json.RawMessage(`{"orders":[{"id":"o1","total":12.5}]}`), + Truncated: false, + Error: nil, + }), + } + srv := newExecuteTestServer(t, Config{ + CodeModeEnabled: true, + NamedOpsEnabled: true, + SessionStateless: false, + Pipeline: pipeline, + ApprovalGate: sandbox.AutoApprove, + }, store) + + got, err := srv.handleExecute(context.Background(), executeToolRequest(t, "session-1", map[string]any{ + "source": "async () => { const r = await tools.someName({}); return r.data; }", + })) + + require.NoError(t, err) + assert.Equal(t, textToolResult(string(pipeline.response.Encoded)), got) + assert.Equal(t, harness.PipelineRequest{ + SessionID: "session-1", + ToolNames: []string{ + "someName", + }, + Source: "async () => { const r = await tools.someName({}); return r.data; }", + RequestHeaders: http.Header{ + 
mcpSessionIDHeader: []string{"session-1"}, + "X-Test": []string{"yes"}, + }, + ApprovalGate: sandbox.AutoApprove, + }, pipeline.lastRequest()) + + var decoded map[string]any + require.NoError(t, json.Unmarshal(pipeline.response.Encoded, &decoded)) + assert.Equal(t, map[string]any{ + "result": map[string]any{ + "orders": []any{ + map[string]any{"id": "o1", "total": 12.5}, + }, + }, + }, decoded) +} + +func TestHandleExecuteSandboxErrorEnvelopeReturnsAsText(t *testing.T) { + store := newExecuteTestStorage() + store.ops["session-1"] = []storage.SessionOp{{Name: "someName", Body: "query SomeName { orders { id } }", Kind: storage.OperationKindQuery}} + pipeline := &recordingPipeline{ + response: pipelineResponse(t, harness.ResultEnvelope{ + Result: json.RawMessage("null"), + Truncated: false, + Error: &harness.ErrorEnvelope{Name: "RuntimeError", Message: "boom", Stack: "stack"}, + }), + } + srv := newExecuteTestServer(t, Config{ + CodeModeEnabled: true, + NamedOpsEnabled: true, + SessionStateless: false, + Pipeline: pipeline, + }, store) + + got, err := srv.handleExecute(context.Background(), executeToolRequest(t, "session-1", map[string]any{ + "source": "async () => { throw new Error('boom'); }", + })) + + require.NoError(t, err) + assert.Equal(t, textToolResult(string(pipeline.response.Encoded)), got) + + var decoded map[string]any + require.NoError(t, json.Unmarshal(pipeline.response.Encoded, &decoded)) + assert.Equal(t, map[string]any{ + "result": nil, + "error": map[string]any{ + "name": "RuntimeError", + "message": "boom", + "stack": "stack", + }, + }, decoded) +} + +func TestHandleExecutePerCallTimeoutRoutesEnvelope(t *testing.T) { + store := newExecuteTestStorage() + store.ops["session-1"] = []storage.SessionOp{{Name: "someName", Body: "query SomeName { orders { id } }", Kind: storage.OperationKindQuery}} + pipeline := &recordingPipeline{sleep: 100 * time.Millisecond} + pipeline.onCancel = pipelineResponse(t, harness.ResultEnvelope{ + Result: 
json.RawMessage("null"), + Truncated: false, + Error: &harness.ErrorEnvelope{Name: "Timeout", Message: "context deadline exceeded", Stack: ""}, + }) + srv := newExecuteTestServer(t, Config{ + CodeModeEnabled: true, + NamedOpsEnabled: true, + SessionStateless: false, + ExecuteTimeout: 10 * time.Millisecond, + Pipeline: pipeline, + }, store) + + got, err := srv.handleExecute(context.Background(), executeToolRequest(t, "session-1", map[string]any{ + "source": "async () => tools.someName({})", + })) + + require.NoError(t, err) + assert.Equal(t, textToolResult(string(pipeline.onCancel.Encoded)), got) + + var decoded map[string]any + require.NoError(t, json.Unmarshal(pipeline.onCancel.Encoded, &decoded)) + assert.Equal(t, map[string]any{ + "result": nil, + "error": map[string]any{ + "name": "Timeout", + "message": "context deadline exceeded", + "stack": "", + }, + }, decoded) +} + +func TestHandleExecuteTranspileErrorEnvelopeReturnsAsText(t *testing.T) { + store := newExecuteTestStorage() + store.ops["session-1"] = []storage.SessionOp{{Name: "someName", Body: "query SomeName { orders { id } }", Kind: storage.OperationKindQuery}} + srv := newExecuteTestServer(t, Config{ + CodeModeEnabled: true, + NamedOpsEnabled: true, + SessionStateless: false, + Pipeline: &harness.Pipeline{}, + }, store) + + got, err := srv.handleExecute(context.Background(), executeToolRequest(t, "session-1", map[string]any{ + "source": "async () => { let x = ; }", + })) + + require.NoError(t, err) + require.Len(t, got.Content, 1) + text, ok := got.Content[0].(*mcp.TextContent) + require.True(t, ok) + + var decoded map[string]any + require.NoError(t, json.Unmarshal([]byte(text.Text), &decoded)) + assert.Equal(t, map[string]any{ + "result": nil, + "error": map[string]any{ + "name": "TranspileError", + "message": "transpile failed: Unexpected \";\"", + "stack": "", + }, + }, decoded) +} + +func TestPersistedOpsResourceReturnsCumulativeBundle(t *testing.T) { + schema := searchHandlerTestSchema(t) + store 
:= storage.NewMemoryBackend(storage.MemoryConfig{Renderer: tsgen.Adapter(schema, 0)}) + store.SetSchema(schema) + _, err := store.Append(context.Background(), "session-1", []storage.SessionOp{{ + Name: "getOrders", + Body: "query GetOrders($limit: Int) { orders(limit: $limit) { id total } }", + Kind: storage.OperationKindQuery, + Description: "Fetch orders.", + }}) + require.NoError(t, err) + srv := newExecuteTestServer(t, Config{ + CodeModeEnabled: true, + NamedOpsEnabled: true, + SessionStateless: false, + Storage: store, + Pipeline: &recordingPipeline{}, + }, nil) + + got, err := srv.handlePersistedOpsResource(context.Background(), resourceRequest("session-1")) + + require.NoError(t, err) + wantBundle, err := tsgen.RenderBundle([]storage.SessionOp{{ + Name: "getOrders", + Body: "query GetOrders($limit: Int) { orders(limit: $limit) { id total } }", + Kind: storage.OperationKindQuery, + Description: "Fetch orders.", + }}, schema, 0) + require.NoError(t, err) + assert.Equal(t, &mcp.ReadResourceResult{ + Contents: []*mcp.ResourceContents{{ + URI: persistedOpsURI, + MIMEType: "text/plain", + Text: wantBundle, + }}, + }, got) +} + +func TestPersistedOpsResourceWithoutSessionReturnsEmptyBundle(t *testing.T) { + schema := searchHandlerTestSchema(t) + store := storage.NewMemoryBackend(storage.MemoryConfig{Renderer: tsgen.Adapter(schema, 0)}) + store.SetSchema(schema) + srv := newExecuteTestServer(t, Config{ + CodeModeEnabled: true, + NamedOpsEnabled: true, + SessionStateless: false, + Storage: store, + Pipeline: &recordingPipeline{}, + }, nil) + + got, err := srv.handlePersistedOpsResource(context.Background(), resourceRequest("")) + + require.NoError(t, err) + wantBundle, err := tsgen.RenderBundle(nil, schema, 0) + require.NoError(t, err) + assert.Equal(t, &mcp.ReadResourceResult{ + Contents: []*mcp.ResourceContents{{ + URI: persistedOpsURI, + MIMEType: "text/plain", + Text: wantBundle, + }}, + }, got) +} + +type recordingPipeline struct { + mu sync.Mutex + requests 
[]harness.PipelineRequest + response harness.PipelineResponse + onCancel harness.PipelineResponse + sleep time.Duration + err error + lastSpan trace.SpanContext +} + +func (p *recordingPipeline) Execute(ctx context.Context, req harness.PipelineRequest) (harness.PipelineResponse, error) { + p.mu.Lock() + p.requests = append(p.requests, req) + p.lastSpan = trace.SpanFromContext(ctx).SpanContext() + p.mu.Unlock() + + if p.sleep > 0 { + select { + case <-ctx.Done(): + return p.onCancel, nil + case <-time.After(p.sleep): + } + } + if p.err != nil { + return harness.PipelineResponse{}, p.err + } + return p.response, nil +} + +func (p *recordingPipeline) lastRequest() harness.PipelineRequest { + p.mu.Lock() + defer p.mu.Unlock() + if len(p.requests) == 0 { + return harness.PipelineRequest{} + } + return p.requests[len(p.requests)-1] +} + +func (p *recordingPipeline) lastSpanContext() trace.SpanContext { + p.mu.Lock() + defer p.mu.Unlock() + return p.lastSpan +} + +type executeTestStorage struct { + mu sync.Mutex + ops map[string][]storage.SessionOp +} + +func newExecuteTestStorage() *executeTestStorage { + return &executeTestStorage{ops: make(map[string][]storage.SessionOp)} +} + +func (s *executeTestStorage) Append(_ context.Context, sessionID string, ops []storage.SessionOp) ([]storage.SessionOp, error) { + s.mu.Lock() + defer s.mu.Unlock() + s.ops[sessionID] = append(s.ops[sessionID], ops...) 
+ return ops, nil +} + +func (s *executeTestStorage) GetOp(_ context.Context, sessionID string, name string) (storage.SessionOp, bool, error) { + s.mu.Lock() + defer s.mu.Unlock() + for _, op := range s.ops[sessionID] { + if op.Name == name { + return op, true, nil + } + } + return storage.SessionOp{}, false, nil +} + +func (s *executeTestStorage) ListNames(_ context.Context, sessionID string) ([]string, error) { + s.mu.Lock() + defer s.mu.Unlock() + names := make([]string, 0, len(s.ops[sessionID])) + for _, op := range s.ops[sessionID] { + names = append(names, op.Name) + } + return names, nil +} + +func (s *executeTestStorage) Bundle(context.Context, string) (string, error) { + return "", nil +} + +func (s *executeTestStorage) Reset(_ context.Context, sessionID string) error { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.ops, sessionID) + return nil +} + +func (s *executeTestStorage) SetSchema(*ast.Document) {} + +func (s *executeTestStorage) Schema() *ast.Document { return nil } + +func (s *executeTestStorage) Start(context.Context) error { return nil } + +func (s *executeTestStorage) Stop() error { return nil } + +func pipelineResponse(t *testing.T, envelope harness.ResultEnvelope) harness.PipelineResponse { + t.Helper() + encoded, err := json.Marshal(envelope) + require.NoError(t, err) + return harness.PipelineResponse{Envelope: envelope, Encoded: encoded} +} + +func executeToolRequest(t *testing.T, sessionID string, arguments map[string]any) *mcp.CallToolRequest { + t.Helper() + body, err := json.Marshal(arguments) + require.NoError(t, err) + return &mcp.CallToolRequest{ + Params: &mcp.CallToolParamsRaw{ + Name: "code_mode_run_js", + Arguments: body, + }, + Extra: &mcp.RequestExtra{Header: http.Header{ + mcpSessionIDHeader: []string{sessionID}, + "X-Test": []string{"yes"}, + }}, + } +} + +func resourceRequest(sessionID string) *mcp.ReadResourceRequest { + return &mcp.ReadResourceRequest{ + Params: &mcp.ReadResourceParams{URI: persistedOpsURI}, + Extra: 
&mcp.RequestExtra{Header: http.Header{mcpSessionIDHeader: []string{sessionID}}}, + } +} + +func newExecuteTestServer(t *testing.T, cfg Config, store storage.SessionStorage) *Server { + t.Helper() + if store != nil { + cfg.Storage = store + } + if cfg.Logger == nil { + cfg.Logger = zap.NewNop() + } + srv, err := New(cfg) + require.NoError(t, err) + return srv +} diff --git a/router/internal/codemode/server/lifecycle.go b/router/internal/codemode/server/lifecycle.go new file mode 100644 index 0000000000..42d3be7600 --- /dev/null +++ b/router/internal/codemode/server/lifecycle.go @@ -0,0 +1,182 @@ +package server + +import ( + "context" + "fmt" + "net/http" + + "github.com/wundergraph/cosmo/router/internal/codemode/harness" + "github.com/wundergraph/cosmo/router/internal/codemode/sandbox" + "github.com/wundergraph/cosmo/router/internal/codemode/storage" + "github.com/wundergraph/cosmo/router/internal/codemode/tsgen" + "github.com/wundergraph/cosmo/router/internal/codemode/yoko" + "github.com/wundergraph/cosmo/router/internal/rediscloser" + "github.com/wundergraph/cosmo/router/pkg/config" + "go.opentelemetry.io/otel/metric" + "go.opentelemetry.io/otel/trace" + "go.uber.org/zap" +) + +type BuildOptions struct { + Config config.MCPCodeModeConfiguration + SessionStateless bool + RouterGraphQLURL string + Logger *zap.Logger + TracerProvider trace.TracerProvider + MeterProvider metric.MeterProvider + + // RedisProvider is the resolved storage_providers.redis entry referenced by + // cfg.NamedOps.Storage.ProviderID. When nil, the in-memory backend is used. + // Provider lookup (and the "unknown id" error) is performed by the router. + RedisProvider *config.RedisStorageProvider + // RedisFactory is an optional override used by tests. When nil, the default + // rediscloser.NewRedisCloser is used. 
+ RedisFactory func(opts *rediscloser.RedisCloserOptions) (rediscloser.RDCloser, error) +} + +func BuildFromConfig(opts BuildOptions) (*Server, error) { + logger := opts.Logger + if logger == nil { + logger = zap.NewNop() + } + + cfg := opts.Config + if !cfg.Enabled { + return New(Config{ + ListenAddr: cfg.Server.ListenAddr, + CodeModeEnabled: cfg.Enabled, + NamedOpsEnabled: cfg.NamedOps.Enabled, + SessionStateless: opts.SessionStateless, + ExecuteTimeout: cfg.ExecuteTimeout, + MaxResultBytes: cfg.MaxResultBytes, + Logger: logger, + TracerProvider: opts.TracerProvider, + MeterProvider: opts.MeterProvider, + ApprovalGate: sandbox.AutoApprove, + CallTraceRecorder: nil, + }) + } + + renderer := tsgen.Adapter(nil, cfg.NamedOps.MaxBundleBytes) + store, err := buildStorage(cfg, renderer, opts, logger) + if err != nil { + return nil, err + } + + sbx, err := sandbox.New(sandbox.Config{ + RouterGraphQLEndpoint: opts.RouterGraphQLURL, + RequestTimeout: cfg.Sandbox.Timeout, + MemoryLimitBytes: cfg.Sandbox.MaxMemoryMB * 1024 * 1024, + MaxInputSizeBytes: cfg.Sandbox.MaxInputSizeBytes, + MaxOutputSizeBytes: cfg.Sandbox.MaxOutputSizeBytes, + MaxResultBytes: cfg.MaxResultBytes, + StorageLookup: func(ctx context.Context, sessionID string, name string) (storage.SessionOp, bool, error) { + if store == nil { + return storage.SessionOp{}, false, nil + } + return store.GetOp(ctx, sessionID, name) + }, + Logger: logger, + }) + if err != nil { + return nil, fmt.Errorf("create code mode sandbox: %w", err) + } + + return New(Config{ + ListenAddr: cfg.Server.ListenAddr, + CodeModeEnabled: cfg.Enabled, + NamedOpsEnabled: cfg.NamedOps.Enabled, + SessionStateless: opts.SessionStateless, + Storage: store, + Pipeline: &harness.Pipeline{Sandbox: sbx, MaxInputBytes: cfg.Sandbox.MaxInputSizeBytes, MaxResultBytes: cfg.MaxResultBytes}, + YokoClient: buildYokoClient(cfg.QueryGeneration, logger), + BundleRenderer: renderer, + ExecuteTimeout: cfg.ExecuteTimeout, + MaxResultBytes: cfg.MaxResultBytes, + 
ApprovalGate: buildApprovalGate(cfg, logger), + Logger: logger, + MeterProvider: opts.MeterProvider, + TracerProvider: opts.TracerProvider, + CallTraceRecorder: nil, + }) +} + +func buildStorage(cfg config.MCPCodeModeConfiguration, renderer storage.Renderer, opts BuildOptions, logger *zap.Logger) (storage.SessionStorage, error) { + if !cfg.NamedOps.Enabled { + return nil, nil + } + + if opts.RedisProvider == nil { + return storage.NewMemoryBackend(storage.MemoryConfig{ + SessionTTL: cfg.NamedOps.SessionTTL, + MaxSessions: cfg.NamedOps.MaxSessions, + MaxBundleBytes: cfg.NamedOps.MaxBundleBytes, + Renderer: renderer, + }), nil + } + + factory := opts.RedisFactory + if factory == nil { + factory = rediscloser.NewRedisCloser + } + client, err := factory(&rediscloser.RedisCloserOptions{ + Logger: logger, + URLs: opts.RedisProvider.URLs, + ClusterEnabled: opts.RedisProvider.ClusterEnabled, + }) + if err != nil { + return nil, fmt.Errorf("create code mode redis storage client: %w", err) + } + backend, err := storage.NewRedisBackend(storage.RedisConfig{ + Client: client, + KeyPrefix: cfg.NamedOps.Storage.KeyPrefix, + SessionTTL: cfg.NamedOps.SessionTTL, + Renderer: renderer, + Logger: logger, + }) + if err != nil { + return nil, fmt.Errorf("create code mode redis storage backend: %w", err) + } + return backend, nil +} + +func buildYokoClient(cfg config.MCPCodeModeQueryGenConfig, logger *zap.Logger) *yoko.Client { + if !cfg.Enabled { + return nil + } + client := &http.Client{Timeout: cfg.Timeout} + if token := cfg.Auth.StaticToken; cfg.Auth.Type == "" || cfg.Auth.Type == "static" { + if token != "" { + client.Transport = staticBearerRoundTripper{ + token: token, + next: http.DefaultTransport, + } + } + } else if cfg.Auth.Type == "jwt" { + logger.Warn("code mode query generation jwt auth is not implemented; proceeding without auth") + } + return yoko.New(client, cfg.Endpoint, logger) +} + +func buildApprovalGate(cfg config.MCPCodeModeConfiguration, _ *zap.Logger) 
sandbox.ApprovalGate { + if cfg.RequireMutationApproval { + return nil + } + return sandbox.AutoApprove +} + +type staticBearerRoundTripper struct { + token string + next http.RoundTripper +} + +func (t staticBearerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + next := t.next + if next == nil { + next = http.DefaultTransport + } + cloned := req.Clone(req.Context()) + cloned.Header = req.Header.Clone() + cloned.Header.Set("Authorization", "Bearer "+t.token) + return next.RoundTrip(cloned) +} diff --git a/router/internal/codemode/server/lifecycle_test.go b/router/internal/codemode/server/lifecycle_test.go new file mode 100644 index 0000000000..349e2f1419 --- /dev/null +++ b/router/internal/codemode/server/lifecycle_test.go @@ -0,0 +1,206 @@ +package server + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/alicebob/miniredis/v2" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router/internal/codemode/storage" + "github.com/wundergraph/cosmo/router/internal/codemode/yoko" + "github.com/wundergraph/cosmo/router/internal/rediscloser" + "github.com/wundergraph/cosmo/router/pkg/config" + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" + "github.com/wundergraph/graphql-go-tools/v2/pkg/astparser" + "github.com/wundergraph/graphql-go-tools/v2/pkg/asttransform" + "go.uber.org/zap" +) + +func TestBuildFromConfigDisabledIsNoOp(t *testing.T) { + srv, err := BuildFromConfig(BuildOptions{ + Config: config.MCPCodeModeConfiguration{Enabled: false}, + SessionStateless: false, + Logger: zap.NewNop(), + }) + require.NoError(t, err) + + require.NoError(t, srv.Start(context.Background())) + assert.Equal(t, "", srv.addr()) + require.NoError(t, srv.Reload(&ast.Document{}, "schema { query: Query }")) + require.NoError(t, srv.Stop(context.Background())) +} + +func TestBuildFromConfigMemoryBackendReloadsSchemaAndSDL(t *testing.T) { + cfg := 
fullLifecycleConfig() + srv, err := BuildFromConfig(BuildOptions{ + Config: cfg, + SessionStateless: false, + RouterGraphQLURL: "http://router.local/graphql", + Logger: zap.NewNop(), + }) + require.NoError(t, err) + + backend, ok := srv.storage.(*storage.MemoryBackend) + require.True(t, ok) + + schema := lifecycleTestSchema(t) + require.NoError(t, srv.Reload(schema, "type Query { orders: [Order!]! }")) + + assert.Equal(t, schema, backend.Schema()) + client, ok := srv.yokoClient.(*yoko.Client) + require.True(t, ok) + assert.Equal(t, "type Query { orders: [Order!]! }", client.Schema()) +} + +func TestBuildFromConfigRedisFactoryError(t *testing.T) { + cfg := fullLifecycleConfig() + cfg.NamedOps.Storage.ProviderID = "my_redis" + + srv, err := BuildFromConfig(BuildOptions{ + Config: cfg, + SessionStateless: false, + RouterGraphQLURL: "http://router.local/graphql", + Logger: zap.NewNop(), + RedisProvider: &config.RedisStorageProvider{ + ID: "my_redis", + URLs: []string{"redis://127.0.0.1:6379"}, + }, + RedisFactory: func(*rediscloser.RedisCloserOptions) (rediscloser.RDCloser, error) { + return nil, errors.New("redis unavailable") + }, + }) + + require.Nil(t, srv) + require.ErrorContains(t, err, "create code mode redis storage client: redis unavailable") +} + +func TestBuildFromConfigRedisBackendWithMiniredis(t *testing.T) { + mr, err := miniredis.Run() + if err != nil { + if isBindPermissionError(err) { + t.Skipf("local miniredis bind is not permitted in this environment: %v", err) + } + require.NoError(t, err) + } + t.Cleanup(mr.Close) + var gotOpts rediscloser.RedisCloserOptions + var client *redis.Client + t.Cleanup(func() { + if client != nil { + require.NoError(t, client.Close()) + } + }) + + cfg := fullLifecycleConfig() + cfg.NamedOps.Storage.ProviderID = "my_redis" + cfg.NamedOps.Storage.KeyPrefix = "test_code_mode" + + srv, err := BuildFromConfig(BuildOptions{ + Config: cfg, + SessionStateless: false, + RouterGraphQLURL: "http://router.local/graphql", + Logger: 
zap.NewNop(), + RedisProvider: &config.RedisStorageProvider{ + ID: "my_redis", + URLs: []string{"redis://" + mr.Addr()}, + ClusterEnabled: true, + }, + RedisFactory: func(opts *rediscloser.RedisCloserOptions) (rediscloser.RDCloser, error) { + gotOpts = *opts + client = redis.NewClient(&redis.Options{Addr: mr.Addr()}) + return client, nil + }, + }) + require.NoError(t, err) + + _, ok := srv.storage.(*storage.RedisBackend) + require.True(t, ok) + assert.NotNil(t, gotOpts.Logger) + assert.Equal(t, []string{"redis://" + mr.Addr()}, gotOpts.URLs) + assert.Equal(t, true, gotOpts.ClusterEnabled) +} + +func TestBuildFromConfigReloadEvictsMemorySessions(t *testing.T) { + srv, err := BuildFromConfig(BuildOptions{ + Config: fullLifecycleConfig(), + SessionStateless: false, + RouterGraphQLURL: "http://router.local/graphql", + Logger: zap.NewNop(), + }) + require.NoError(t, err) + + _, err = srv.storage.Append(context.Background(), "session-1", []storage.SessionOp{{ + Name: "getOrders", + Body: "query GetOrders { orders { id } }", + Kind: storage.OperationKindQuery, + Description: "Fetch orders.", + }}) + require.NoError(t, err) + + _, ok, err := srv.storage.GetOp(context.Background(), "session-1", "getOrders") + require.NoError(t, err) + assert.Equal(t, true, ok) + + require.NoError(t, srv.Reload(lifecycleTestSchema(t), "type Query { customer: Customer }")) + + got, ok, err := srv.storage.GetOp(context.Background(), "session-1", "getOrders") + require.NoError(t, err) + assert.Equal(t, false, ok) + assert.Equal(t, storage.SessionOp{}, got) +} + +func TestBuildFromConfigDisabledReloadIsNoOp(t *testing.T) { + srv, err := BuildFromConfig(BuildOptions{ + Config: config.MCPCodeModeConfiguration{Enabled: false}, + SessionStateless: false, + Logger: zap.NewNop(), + }) + require.NoError(t, err) + + require.NoError(t, srv.Reload(lifecycleTestSchema(t), "type Query { orders: [Order!]! 
}")) + assert.Nil(t, srv.storage) + assert.Nil(t, srv.yokoClient) +} + +func fullLifecycleConfig() config.MCPCodeModeConfiguration { + return config.MCPCodeModeConfiguration{ + Enabled: true, + Server: config.MCPCodeModeServerConfig{ListenAddr: "127.0.0.1:0"}, + RequireMutationApproval: true, + ExecuteTimeout: 120 * time.Second, + MaxResultBytes: 32 << 10, + Sandbox: config.MCPCodeModeSandboxConfig{ + Timeout: 5 * time.Second, + MaxMemoryMB: 16, + MaxInputSizeBytes: 64 << 10, + MaxOutputSizeBytes: 1 << 20, + }, + QueryGeneration: config.MCPCodeModeQueryGenConfig{ + Enabled: true, + Endpoint: "http://yoko.local", + Timeout: 10 * time.Second, + Auth: config.MCPCodeModeQueryGenAuthConfig{Type: "static", StaticToken: "token"}, + }, + NamedOps: config.MCPCodeModeNamedOpsConfig{ + Enabled: true, + SessionTTL: 30 * time.Minute, + MaxSessions: 1000, + MaxBundleBytes: 256 << 10, + Storage: config.MCPCodeModeNamedOpsStorageConfig{ + KeyPrefix: "cosmo_code_mode", + }, + }, + } +} + +func lifecycleTestSchema(t *testing.T) *ast.Document { + t.Helper() + doc, report := astparser.ParseGraphqlDocumentString(searchHandlerTestSchemaSDL) + require.False(t, report.HasErrors(), report.Error()) + require.NoError(t, asttransform.MergeDefinitionWithBaseSchema(&doc)) + return &doc +} diff --git a/router/internal/codemode/server/observability_handler_test.go b/router/internal/codemode/server/observability_handler_test.go new file mode 100644 index 0000000000..8a4621326e --- /dev/null +++ b/router/internal/codemode/server/observability_handler_test.go @@ -0,0 +1,180 @@ +package server + +import ( + "context" + "encoding/json" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + yokov1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/code_mode/yoko/v1" + "github.com/wundergraph/cosmo/router/internal/codemode/harness" + "github.com/wundergraph/cosmo/router/internal/codemode/sandbox" + 
"github.com/wundergraph/cosmo/router/internal/codemode/storage" + "go.opentelemetry.io/otel/attribute" + sdkmetric "go.opentelemetry.io/otel/sdk/metric" + "go.opentelemetry.io/otel/sdk/metric/metricdata" + sdktrace "go.opentelemetry.io/otel/sdk/trace" + "go.opentelemetry.io/otel/sdk/trace/tracetest" + "go.opentelemetry.io/otel/trace" + "go.uber.org/zap" +) + +func TestHandleSearchRecordsObservability(t *testing.T) { + traces, meterProvider, reader := newHandlerTelemetry() + searcher := newFakeYoko() + searcher.responses <- &yokov1.SearchResponse{Operations: []*yokov1.GeneratedOperation{{ + Name: "getOrders", + Body: "query GetOrders { orders { id } }", + Kind: yokov1.OperationKind_OPERATION_KIND_QUERY, + }}} + store := newSearchTestStorage(t) + srv, err := New(Config{ + CodeModeEnabled: true, + NamedOpsEnabled: true, + Storage: store, + YokoClient: searcher, + Logger: zap.NewNop(), + TracerProvider: sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(traces)), + MeterProvider: meterProvider, + }) + require.NoError(t, err) + + got, err := srv.handleSearch(context.Background(), searchToolRequest(t, "session-1", map[string]any{ + "prompts": []string{"orders"}, + })) + + require.NoError(t, err) + require.False(t, got.IsError) + assert.Equal(t, []tracetest.SpanStub{{ + Name: "MCP Code Mode - Search", + SpanKind: trace.SpanKindServer, + Attributes: []attribute.KeyValue{ + attribute.String("mcp.tool", "code_mode_search_tools"), + attribute.String("mcp.status", "success"), + }, + InstrumentationLibrary: normalizedSpanStubs(traces.Ended())[0].InstrumentationLibrary, + }}, normalizedSpanStubs(traces.Ended())) + assertCodeModeMetric(t, reader, "code_mode_search_tools", "success") +} + +func TestHandleExecuteRecordsObservability(t *testing.T) { + traces, meterProvider, reader := newHandlerTelemetry() + store := newExecuteTestStorage() + store.ops["session-1"] = []storage.SessionOp{{ + Name: "someName", + Body: "query SomeName { orders { id total } }", + Kind: 
storage.OperationKindQuery, + }} + pipeline := &recordingPipeline{ + response: pipelineResponse(t, harness.ResultEnvelope{ + Result: json.RawMessage(`{"orders":[{"id":"o1"}]}`), + Truncated: false, + Error: nil, + }), + } + srv := newExecuteTestServer(t, Config{ + CodeModeEnabled: true, + NamedOpsEnabled: true, + SessionStateless: false, + Pipeline: pipeline, + ApprovalGate: sandbox.AutoApprove, + Logger: zap.NewNop(), + TracerProvider: sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(traces)), + MeterProvider: meterProvider, + }, store) + + got, err := srv.handleExecute(context.Background(), executeToolRequest(t, "session-1", map[string]any{ + "source": "async () => tools.someName({})", + })) + + require.NoError(t, err) + require.False(t, got.IsError) + assert.Equal(t, []tracetest.SpanStub{{ + Name: "MCP Code Mode - Execute", + SpanKind: trace.SpanKindServer, + Attributes: []attribute.KeyValue{ + attribute.String("mcp.tool", "code_mode_run_js"), + attribute.String("mcp.status", "success"), + }, + InstrumentationLibrary: normalizedSpanStubs(traces.Ended())[0].InstrumentationLibrary, + }}, normalizedSpanStubs(traces.Ended())) + assertCodeModeMetric(t, reader, "code_mode_run_js", "success") + require.True(t, pipeline.lastSpanContext().IsValid()) +} + +func newHandlerTelemetry() (*tracetest.SpanRecorder, *sdkmetric.MeterProvider, *sdkmetric.ManualReader) { + reader := sdkmetric.NewManualReader() + return tracetest.NewSpanRecorder(), sdkmetric.NewMeterProvider(sdkmetric.WithReader(reader)), reader +} + +func normalizedSpanStubs(spans []sdktrace.ReadOnlySpan) []tracetest.SpanStub { + stubs := make([]tracetest.SpanStub, 0, len(spans)) + for _, span := range spans { + stub := tracetest.SpanStubFromReadOnlySpan(span) + stub.SpanContext = trace.SpanContext{} + stub.StartTime = time.Time{} + stub.EndTime = time.Time{} + stub.Resource = nil + stubs = append(stubs, stub) + } + return stubs +} + +func assertCodeModeMetric(t *testing.T, reader *sdkmetric.ManualReader, 
toolName string, status string) { + t.Helper() + var got metricdata.ResourceMetrics + require.NoError(t, reader.Collect(context.Background(), &got)) + + counter, histogram := handlerCodeModeMetrics(t, got) + counterData, ok := counter.Data.(metricdata.Sum[int64]) + require.True(t, ok) + require.Len(t, counterData.DataPoints, 1) + counterPoint := counterData.DataPoints[0] + counterPoint.StartTime = time.Time{} + counterPoint.Time = time.Time{} + assert.Equal(t, metricdata.DataPoint[int64]{ + Attributes: attribute.NewSet( + attribute.String("mcp.tool", toolName), + attribute.String("mcp.status", status), + ), + Value: 1, + }, counterPoint) + + histogramData, ok := histogram.Data.(metricdata.Histogram[float64]) + require.True(t, ok) + require.Len(t, histogramData.DataPoints, 1) + histogramPoint := histogramData.DataPoints[0] + require.Greater(t, histogramPoint.Sum, 0.0) + histogramPoint.StartTime = time.Time{} + histogramPoint.Time = time.Time{} + assert.Equal(t, metricdata.HistogramDataPoint[float64]{ + Attributes: attribute.NewSet( + attribute.String("mcp.tool", toolName), + attribute.String("mcp.status", status), + ), + Count: 1, + Bounds: histogramPoint.Bounds, + BucketCounts: histogramPoint.BucketCounts, + Min: histogramPoint.Min, + Max: histogramPoint.Max, + Sum: histogramPoint.Sum, + }, histogramPoint) +} + +func handlerCodeModeMetrics(t *testing.T, metrics metricdata.ResourceMetrics) (metricdata.Metrics, metricdata.Metrics) { + t.Helper() + require.Len(t, metrics.ScopeMetrics, 1) + assert.Equal(t, "wundergraph.cosmo.router.mcp.code_mode", metrics.ScopeMetrics[0].Scope.Name) + + byName := make(map[string]metricdata.Metrics, len(metrics.ScopeMetrics[0].Metrics)) + for _, metric := range metrics.ScopeMetrics[0].Metrics { + byName[metric.Name] = metric + } + counter, ok := byName["mcp.code_mode.sandbox.executions"] + require.True(t, ok) + histogram, ok := byName["mcp.code_mode.sandbox.duration"] + require.True(t, ok) + return counter, histogram +} diff --git 
a/router/internal/codemode/server/search_handler.go b/router/internal/codemode/server/search_handler.go new file mode 100644 index 0000000000..8860cc3eca --- /dev/null +++ b/router/internal/codemode/server/search_handler.go @@ -0,0 +1,264 @@ +package server + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "sort" + "strings" + + "github.com/modelcontextprotocol/go-sdk/mcp" + yokov1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/code_mode/yoko/v1" + "github.com/wundergraph/cosmo/router/internal/codemode/storage" + "go.uber.org/zap" +) + +const ( + maxSearchPrompts = 20 + emptySearchAPIResponseMessage = "// 0 new ops; previous code_mode_search_tools calls already cover these prompts." + + // The generated proto currently has query and mutation constants. Yoko may + // still send the planned subscription enum value; host behavior is to drop it. + yokoOperationKindSubscription yokov1.OperationKind = 3 +) + +type searchAPIInput struct { + Prompts []string `json:"prompts"` +} + +type legacyCatalogueOperation struct { + Name string `json:"name"` + Body string `json:"body"` + Kind string `json:"kind"` + Description string `json:"description"` + Variables *string `json:"variables"` +} + +func (s *Server) handleSearchAPI(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + ctx = contextWithSessionFromExtra(ctx, req.GetExtra()) + + prompts, validationErr := decodeSearchPrompts(req) + if validationErr != nil { + return toolErrorResult(validationErr.Error()), nil + } + + if s.sessionStateless { + return s.handleSearchStateless(ctx, prompts), nil + } + + sessionID := SessionIDFromContext(ctx) + if sessionID == "" { + s.warnMissingSessionIDOnce() + return s.handleSearchStateless(ctx, prompts), nil + } + + key := searchSingleFlightKey(sessionID, prompts) + value, _, _ := s.searchGroup.Do(key, func() (any, error) { + return s.handleSearchStateful(ctx, sessionID, prompts), nil + }) + return value.(*mcp.CallToolResult), nil +} + +func 
decodeSearchPrompts(req *mcp.CallToolRequest) ([]string, error) { + var input searchAPIInput + if req != nil && req.Params != nil && len(req.Params.Arguments) > 0 { + if err := json.Unmarshal(req.Params.Arguments, &input); err != nil { + return nil, errors.New("code_mode_search_tools: prompts must be a non-empty array of strings") + } + } + + if len(input.Prompts) == 0 { + return nil, errors.New("code_mode_search_tools: prompts must be a non-empty array of strings") + } + if len(input.Prompts) > maxSearchPrompts { + return nil, fmt.Errorf("too many prompts: %d (max 20) — pass all prompts in one call", len(input.Prompts)) + } + for i, prompt := range input.Prompts { + if strings.TrimSpace(prompt) == "" { + return nil, fmt.Errorf("code_mode_search_tools: prompt at index %d is empty", i) + } + } + return input.Prompts, nil +} + +func (s *Server) handleSearchStateless(ctx context.Context, prompts []string) *mcp.CallToolResult { + response, err := s.searchYoko(ctx, "", prompts) + if err != nil { + return toolErrorResult(fmt.Sprintf("code_mode_search_tools: yoko search failed: %v", err)) + } + + catalogue := make([]legacyCatalogueOperation, 0, len(response.GetOperations())) + droppedSubscription := false + for _, op := range response.GetOperations() { + kind, ok, subscription := yokoOperationKindLabel(op.GetKind()) + if subscription { + droppedSubscription = true + continue + } + if !ok { + s.logger.Warn("code_mode_search_tools dropped unsupported operation kind", + zap.String("name", op.GetName()), + zap.String("kind", op.GetKind().String()), + ) + continue + } + catalogue = append(catalogue, legacyCatalogueOperation{ + Name: op.GetName(), + Body: op.GetBody(), + Kind: kind, + Description: op.GetDescription(), + Variables: extractGraphQLVariablesBlock(op.GetBody()), + }) + } + if droppedSubscription { + s.logger.Warn("code_mode_search_tools dropped subscription operations returned by yoko") + } + + encoded, err := json.Marshal(catalogue) + if err != nil { + return 
toolErrorResult(fmt.Sprintf("code_mode_search_tools: failed to encode legacy catalogue: %v", err)) + } + return textResult(string(encoded)) +} + +func (s *Server) handleSearchStateful(ctx context.Context, sessionID string, prompts []string) *mcp.CallToolResult { + response, err := s.searchYoko(ctx, sessionID, prompts) + if err != nil { + return toolErrorResult(fmt.Sprintf("code_mode_search_tools: yoko search failed: %v", err)) + } + + rawOps := make([]storage.SessionOp, 0, len(response.GetOperations())) + droppedSubscription := false + for _, op := range response.GetOperations() { + kind, ok, subscription := storageOperationKind(op.GetKind()) + if subscription { + droppedSubscription = true + continue + } + if !ok { + s.logger.Warn("code_mode_search_tools dropped unsupported operation kind", + zap.String("name", op.GetName()), + zap.String("kind", op.GetKind().String()), + ) + continue + } + rawOps = append(rawOps, storage.SessionOp{ + Name: storage.NormalizeName(op.GetName()), + Body: op.GetBody(), + Kind: kind, + Description: op.GetDescription(), + }) + } + if droppedSubscription { + s.logger.Warn("code_mode_search_tools dropped subscription operations returned by yoko") + } + + if len(rawOps) == 0 { + return textResult(emptySearchAPIResponseMessage) + } + if s.storage == nil { + return toolErrorResult("code_mode_search_tools: failed to register ops: code mode storage is not configured") + } + + // Collision handling approach: Append-applies-suffix. The storage backend is + // the serialization point for a session and returns the final stored names. 
+ appendedOps, err := s.storage.Append(ctx, sessionID, rawOps) + if err != nil { + return toolErrorResult(fmt.Sprintf("code_mode_search_tools: failed to register ops: %v", err)) + } + if len(appendedOps) == 0 { + return textResult(emptySearchAPIResponseMessage) + } + + rendered, err := s.newOpsFragment(appendedOps, s.storage.Schema()) + if err != nil { + return toolErrorResult(fmt.Sprintf("code_mode_search_tools: failed to render new ops: %v", err)) + } + return textResult(rendered) +} + +func (s *Server) searchYoko(ctx context.Context, sessionID string, prompts []string) (*yokov1.SearchResponse, error) { + if s.yokoClient == nil { + return nil, errors.New("yoko client is not configured") + } + return s.yokoClient.Search(ctx, sessionID, prompts) +} + +func storageOperationKind(kind yokov1.OperationKind) (storage.OperationKind, bool, bool) { + switch kind { + case yokov1.OperationKind_OPERATION_KIND_QUERY: + return storage.OperationKindQuery, true, false + case yokov1.OperationKind_OPERATION_KIND_MUTATION: + return storage.OperationKindMutation, true, false + case yokoOperationKindSubscription: + return "", false, true + default: + return "", false, false + } +} + +func yokoOperationKindLabel(kind yokov1.OperationKind) (string, bool, bool) { + switch kind { + case yokov1.OperationKind_OPERATION_KIND_QUERY: + return "Query", true, false + case yokov1.OperationKind_OPERATION_KIND_MUTATION: + return "Mutation", true, false + case yokoOperationKindSubscription: + return "", false, true + default: + return "", false, false + } +} + +func searchSingleFlightKey(sessionID string, prompts []string) string { + sortedPrompts := append([]string(nil), prompts...) 
+ sort.Strings(sortedPrompts) + keyParts := []string{sessionID} + for _, p := range sortedPrompts { + keyParts = append(keyParts, fmt.Sprintf("%d:%s", len(p), p)) + } + return strings.Join(keyParts, "|") +} + +func extractGraphQLVariablesBlock(body string) *string { + open := strings.IndexByte(body, '(') + if open < 0 { + return nil + } + selection := strings.IndexByte(body, '{') + if selection >= 0 && selection < open { + return nil + } + + depth := 0 + for i := open; i < len(body); i++ { + switch body[i] { + case '(': + depth++ + case ')': + depth-- + if depth == 0 { + value := strings.TrimSpace(body[open : i+1]) + return &value + } + } + } + return nil +} + +func textResult(text string) *mcp.CallToolResult { + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: text}}, + } +} + +func (s *Server) warnMissingSessionIDOnce() { + s.mu.Lock() + defer s.mu.Unlock() + if s.warnedMissingSessionID { + return + } + s.warnedMissingSessionID = true + s.logger.Warn("code mode code_mode_search_tools missing MCP session id; falling back to legacy stateless catalogue") +} diff --git a/router/internal/codemode/server/search_handler_test.go b/router/internal/codemode/server/search_handler_test.go new file mode 100644 index 0000000000..df4c9fdeff --- /dev/null +++ b/router/internal/codemode/server/search_handler_test.go @@ -0,0 +1,663 @@ +package server + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "sync" + "testing" + "time" + + "connectrpc.com/connect" + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + yokov1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/code_mode/yoko/v1" + "github.com/wundergraph/cosmo/router/internal/codemode/storage" + "github.com/wundergraph/cosmo/router/internal/codemode/tsgen" + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" + "github.com/wundergraph/graphql-go-tools/v2/pkg/astparser" + 
"github.com/wundergraph/graphql-go-tools/v2/pkg/asttransform" + "go.uber.org/zap" +) + +const searchHandlerTestSchemaSDL = ` +schema { + query: Query + mutation: Mutation +} + +type Query { + orders(limit: Int): [Order!]! + customer(id: ID!): Customer +} + +type Mutation { + cancelOrder(id: ID!): Order! +} + +type Order { + id: ID! + total: Float! +} + +type Customer { + id: ID! + name: String! +} +` + +const emptySearchMessage = "// 0 new ops; previous code_mode_search_tools calls already cover these prompts." + +func TestHandleSearchValidatesPrompts(t *testing.T) { + tests := []struct { + name string + arguments map[string]any + want string + }{ + { + name: "missing prompts", + arguments: map[string]any{}, + want: "code_mode_search_tools: prompts must be a non-empty array of strings", + }, + { + name: "empty prompts", + arguments: map[string]any{"prompts": []string{}}, + want: "code_mode_search_tools: prompts must be a non-empty array of strings", + }, + { + name: "too many prompts", + arguments: map[string]any{"prompts": func() []string { + prompts := make([]string, 21) + for i := range prompts { + prompts[i] = fmt.Sprintf("prompt %d", i) + } + return prompts + }()}, + want: "too many prompts: 21 (max 20) — pass all prompts in one call", + }, + { + name: "empty prompt", + arguments: map[string]any{"prompts": []string{"orders", " \t\n"}}, + want: "code_mode_search_tools: prompt at index 1 is empty", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + srv := newSearchTestServer(t, false, newFakeYoko(), newSearchTestStorage(t)) + + got, err := srv.handleSearch(context.Background(), searchToolRequest(t, "session-1", tt.arguments)) + + require.NoError(t, err) + assert.Equal(t, toolError(tt.want), got) + }) + } +} + +func TestHandleSearchStatelessReturnsLegacyJSONCatalogue(t *testing.T) { + searcher := newFakeYoko() + searcher.responses <- &yokov1.SearchResponse{Operations: []*yokov1.GeneratedOperation{ + { + Name: "getOrders", + Body: 
"query GetOrders($limit: Int) { orders(limit: $limit) { id } }", + Kind: yokov1.OperationKind_OPERATION_KIND_QUERY, + Description: "Fetch orders.", + }, + { + Name: "watchOrders", + Body: "subscription WatchOrders { orders { id } }", + Kind: yokoOperationKindSubscription, + Description: "Watch orders.", + }, + }} + store := newSearchTestStorage(t) + srv := newSearchTestServer(t, true, searcher, store) + + got, err := srv.handleSearch(context.Background(), searchToolRequest(t, "", map[string]any{ + "prompts": []string{"orders"}, + })) + + require.NoError(t, err) + expectedJSON := mustJSON(t, []legacyCatalogueEntry{ + { + Name: "getOrders", + Body: "query GetOrders($limit: Int) { orders(limit: $limit) { id } }", + Kind: "Query", + Description: "Fetch orders.", + Variables: ptrString("($limit: Int)"), + }, + }) + assert.Equal(t, textToolResult(expectedJSON), got) + assert.Equal(t, []searchCall{{sessionID: "", prompts: []string{"orders"}}}, searcher.callsSnapshot()) + assert.Equal(t, []storage.SessionOp(nil), store.opsSnapshot("session-1")) +} + +func TestHandleSearchStatefulAppendsAndReturnsNewOpsFragment(t *testing.T) { + searcher := newFakeYoko() + searcher.responses <- &yokov1.SearchResponse{Operations: []*yokov1.GeneratedOperation{ + { + Name: "getOrders", + Body: "query GetOrders($limit: Int) { orders(limit: $limit) { id total } }", + Kind: yokov1.OperationKind_OPERATION_KIND_QUERY, + Description: "Fetch orders.", + }, + { + Name: "cancelOrder", + Body: "mutation CancelOrder($id: ID!) 
{ cancelOrder(id: $id) { id } }", + Kind: yokov1.OperationKind_OPERATION_KIND_MUTATION, + Description: "Cancel an order.", + }, + { + Name: "watchOrders", + Body: "subscription WatchOrders { orders { id } }", + Kind: yokoOperationKindSubscription, + Description: "Watch orders.", + }, + }} + store := newSearchTestStorage(t) + srv := newSearchTestServer(t, false, searcher, store) + + got, err := srv.handleSearch(context.Background(), searchToolRequest(t, "session-1", map[string]any{ + "prompts": []string{"orders", "cancel order"}, + })) + + require.NoError(t, err) + wantOps := []storage.SessionOp{ + { + Name: "getOrders", + Body: "query GetOrders($limit: Int) { orders(limit: $limit) { id total } }", + Kind: storage.OperationKindQuery, + Description: "Fetch orders.", + }, + { + Name: "cancelOrder", + Body: "mutation CancelOrder($id: ID!) { cancelOrder(id: $id) { id } }", + Kind: storage.OperationKindMutation, + Description: "Cancel an order.", + }, + } + wantFragment, err := tsgen.NewOpsFragment(wantOps, searchHandlerTestSchema(t)) + require.NoError(t, err) + assert.Equal(t, textToolResult(wantFragment), got) + assert.Equal(t, wantOps, store.opsSnapshot("session-1")) + assert.Equal(t, []searchCall{{sessionID: "session-1", prompts: []string{"orders", "cancel order"}}}, searcher.callsSnapshot()) +} + +func TestHandleSearchFallsBackToStatelessWhenSessionIDMissing(t *testing.T) { + searcher := newFakeYoko() + searcher.responses <- &yokov1.SearchResponse{Operations: []*yokov1.GeneratedOperation{{ + Name: "getOrders", + Body: "query GetOrders { orders { id } }", + Kind: yokov1.OperationKind_OPERATION_KIND_QUERY, + Description: "Fetch orders.", + }}} + store := newSearchTestStorage(t) + srv := newSearchTestServer(t, false, searcher, store) + + got, err := srv.handleSearch(context.Background(), searchToolRequest(t, "", map[string]any{ + "prompts": []string{"orders"}, + })) + + require.NoError(t, err) + expectedJSON := mustJSON(t, []legacyCatalogueEntry{{ + Name: "getOrders", 
+ Body: "query GetOrders { orders { id } }", + Kind: "Query", + Description: "Fetch orders.", + Variables: nil, + }}) + assert.Equal(t, textToolResult(expectedJSON), got) + assert.Equal(t, []storage.SessionOp(nil), store.opsSnapshot("session-1")) +} + +func TestHandleSearchNamingCollisionUsesFinalStoredName(t *testing.T) { + searcher := newFakeYoko() + searcher.responses <- &yokov1.SearchResponse{Operations: []*yokov1.GeneratedOperation{{ + Name: "getOrders", + Body: "query GetOrdersAgain { orders { total } }", + Kind: yokov1.OperationKind_OPERATION_KIND_QUERY, + Description: "Fetch order totals.", + }}} + store := newSearchTestStorage(t) + _, err := store.Append(context.Background(), "session-1", []storage.SessionOp{{ + Name: "getOrders", + Body: "query GetOrders { orders { id } }", + Kind: storage.OperationKindQuery, + }}) + require.NoError(t, err) + srv := newSearchTestServer(t, false, searcher, store) + + got, err := srv.handleSearch(context.Background(), searchToolRequest(t, "session-1", map[string]any{ + "prompts": []string{"orders again"}, + })) + + require.NoError(t, err) + wantOps := []storage.SessionOp{ + {Name: "getOrders", Body: "query GetOrders { orders { id } }", Kind: storage.OperationKindQuery}, + {Name: "getOrders_2", Body: "query GetOrdersAgain { orders { total } }", Kind: storage.OperationKindQuery, Description: "Fetch order totals."}, + } + wantFragment, err := tsgen.NewOpsFragment(wantOps[1:], searchHandlerTestSchema(t)) + require.NoError(t, err) + assert.Equal(t, textToolResult(wantFragment), got) + assert.Equal(t, wantOps, store.opsSnapshot("session-1")) +} + +func TestHandleSearchEmptyYokoResponseIsSuccess(t *testing.T) { + searcher := newFakeYoko() + searcher.responses <- &yokov1.SearchResponse{} + srv := newSearchTestServer(t, false, searcher, newSearchTestStorage(t)) + + got, err := srv.handleSearch(context.Background(), searchToolRequest(t, "session-1", map[string]any{ + "prompts": []string{"orders"}, + })) + + require.NoError(t, err) + 
assert.Equal(t, textToolResult(emptySearchMessage), got) +} + +func TestHandleSearchDoesNotRetryNotFoundFromSearcher(t *testing.T) { + searcher := newFakeYoko() + searcher.errs <- connect.NewError(connect.CodeNotFound, errors.New("missing index")) + srv := newSearchTestServer(t, false, searcher, newSearchTestStorage(t)) + + got, err := srv.handleSearch(context.Background(), searchToolRequest(t, "session-1", map[string]any{ + "prompts": []string{"orders"}, + })) + + require.NoError(t, err) + assert.Equal(t, toolError("code_mode_search_tools: yoko search failed: not_found: missing index"), got) + assert.Equal(t, 1, searcher.callCount()) +} + +func TestHandleSearchYokoErrorIsToolError(t *testing.T) { + searcher := newFakeYoko() + searcher.errs <- errors.New("dial tcp: connection refused") + srv := newSearchTestServer(t, false, searcher, newSearchTestStorage(t)) + + got, err := srv.handleSearch(context.Background(), searchToolRequest(t, "session-1", map[string]any{ + "prompts": []string{"orders"}, + })) + + require.NoError(t, err) + assert.Equal(t, toolError("code_mode_search_tools: yoko search failed: dial tcp: connection refused"), got) +} + +func TestHandleSearchSingleFlight(t *testing.T) { + t.Run("identical calls share leader result", func(t *testing.T) { + searcher := newFakeYoko() + searcher.block = make(chan struct{}) + searcher.responses <- &yokov1.SearchResponse{Operations: []*yokov1.GeneratedOperation{{ + Name: "getOrders", + Body: "query GetOrders { orders { id } }", + Kind: yokov1.OperationKind_OPERATION_KIND_QUERY, + }}} + srv := newSearchTestServer(t, false, searcher, newSearchTestStorage(t)) + + ctx := context.Background() + var wg sync.WaitGroup + results := make([]*mcp.CallToolResult, 2) + wg.Add(1) + go func() { + defer wg.Done() + result, err := srv.handleSearch(ctx, searchToolRequest(t, "session-1", map[string]any{ + "prompts": []string{"orders", "customers"}, + })) + require.NoError(t, err) + results[0] = result + }() + require.Eventually(t, 
func() bool { return searcher.callCount() == 1 }, time.Second, time.Millisecond) + wg.Add(1) + go func() { + defer wg.Done() + result, err := srv.handleSearch(ctx, searchToolRequest(t, "session-1", map[string]any{ + "prompts": []string{"orders", "customers"}, + })) + require.NoError(t, err) + results[1] = result + }() + time.Sleep(10 * time.Millisecond) + close(searcher.block) + wg.Wait() + + assert.Equal(t, 1, searcher.callCount()) + assert.Equal(t, results[0], results[1]) + }) + + t.Run("different calls do not share result", func(t *testing.T) { + searcher := newFakeYoko() + searcher.responses <- &yokov1.SearchResponse{} + searcher.responses <- &yokov1.SearchResponse{} + srv := newSearchTestServer(t, false, searcher, newSearchTestStorage(t)) + + var wg sync.WaitGroup + for _, prompt := range []string{"orders", "customers"} { + wg.Add(1) + go func(prompt string) { + defer wg.Done() + _, err := srv.handleSearch(context.Background(), searchToolRequest(t, "session-1", map[string]any{ + "prompts": []string{prompt}, + })) + require.NoError(t, err) + }(prompt) + } + wg.Wait() + + assert.Equal(t, 2, searcher.callCount()) + }) + + t.Run("ambiguous spacing prompt sets do not share result", func(t *testing.T) { + searcher := newFakeYoko() + searcher.block = make(chan struct{}) + searcher.responses <- &yokov1.SearchResponse{} + searcher.responses <- &yokov1.SearchResponse{} + srv := newSearchTestServer(t, false, searcher, newSearchTestStorage(t)) + + var wg sync.WaitGroup + for _, prompts := range [][]string{ + {"a b", "c"}, + {"a", "b c"}, + } { + prompts := prompts + wg.Add(1) + go func() { + defer wg.Done() + _, err := srv.handleSearch(context.Background(), searchToolRequest(t, "session-1", map[string]any{ + "prompts": prompts, + })) + require.NoError(t, err) + }() + } + + require.Eventually(t, func() bool { return searcher.callCount() == 2 }, time.Second, time.Millisecond) + close(searcher.block) + wg.Wait() + + assert.Equal(t, 2, searcher.callCount()) + }) +} + +func 
TestHandleSearchRenderErrorIsToolError(t *testing.T) { + searcher := newFakeYoko() + searcher.responses <- &yokov1.SearchResponse{Operations: []*yokov1.GeneratedOperation{{ + Name: "getOrders", + Body: "query GetOrders { orders { id } }", + Kind: yokov1.OperationKind_OPERATION_KIND_QUERY, + }}} + srv := newSearchTestServer(t, false, searcher, newSearchTestStorage(t)) + srv.newOpsFragment = func([]storage.SessionOp, *ast.Document) (string, error) { + return "", errors.New("render exploded") + } + + got, err := srv.handleSearch(context.Background(), searchToolRequest(t, "session-1", map[string]any{ + "prompts": []string{"orders"}, + })) + + require.NoError(t, err) + assert.Equal(t, toolError("code_mode_search_tools: failed to render new ops: render exploded"), got) +} + +func TestHandleSearchCancelMaySurfaceLeaderCancellationToFollower(t *testing.T) { + searcher := newFakeYoko() + searcher.block = make(chan struct{}) + srv := newSearchTestServer(t, false, searcher, newSearchTestStorage(t)) + + leaderCtx, cancelLeader := context.WithCancel(context.Background()) + defer cancelLeader() + + var wg sync.WaitGroup + results := make([]*mcp.CallToolResult, 2) + wg.Add(1) + go func() { + defer wg.Done() + result, err := srv.handleSearch(leaderCtx, searchToolRequest(t, "session-1", map[string]any{ + "prompts": []string{"orders"}, + })) + require.NoError(t, err) + results[0] = result + }() + require.Eventually(t, func() bool { return searcher.callCount() == 1 }, time.Second, time.Millisecond) + + wg.Add(1) + go func() { + defer wg.Done() + result, err := srv.handleSearch(context.Background(), searchToolRequest(t, "session-1", map[string]any{ + "prompts": []string{"orders"}, + })) + require.NoError(t, err) + results[1] = result + }() + time.Sleep(10 * time.Millisecond) + cancelLeader() + close(searcher.block) + wg.Wait() + + assert.Equal(t, 1, searcher.callCount()) + assert.Equal(t, toolError("code_mode_search_tools: yoko search failed: context canceled"), results[0]) + 
assert.Equal(t, toolError("code_mode_search_tools: yoko search failed: context canceled"), results[1]) +} + +type searchCall struct { + sessionID string + prompts []string +} + +type fakeYoko struct { + mu sync.Mutex + calls []searchCall + responses chan *yokov1.SearchResponse + errs chan error + block chan struct{} + schema string +} + +func newFakeYoko() *fakeYoko { + return &fakeYoko{ + responses: make(chan *yokov1.SearchResponse, 16), + errs: make(chan error, 16), + } +} + +func (f *fakeYoko) Search(ctx context.Context, sessionID string, prompts []string) (*yokov1.SearchResponse, error) { + f.mu.Lock() + f.calls = append(f.calls, searchCall{sessionID: sessionID, prompts: append([]string(nil), prompts...)}) + f.mu.Unlock() + + if f.block != nil { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-f.block: + } + } + + select { + case err := <-f.errs: + return nil, err + default: + } + select { + case response := <-f.responses: + return response, nil + default: + return &yokov1.SearchResponse{}, nil + } +} + +func (f *fakeYoko) SetSchema(schema string) { + f.mu.Lock() + defer f.mu.Unlock() + f.schema = schema +} + +func (f *fakeYoko) Schema() string { + f.mu.Lock() + defer f.mu.Unlock() + return f.schema +} + +func (f *fakeYoko) callCount() int { + f.mu.Lock() + defer f.mu.Unlock() + return len(f.calls) +} + +func (f *fakeYoko) callsSnapshot() []searchCall { + f.mu.Lock() + defer f.mu.Unlock() + calls := make([]searchCall, 0, len(f.calls)) + for _, call := range f.calls { + calls = append(calls, searchCall{sessionID: call.sessionID, prompts: append([]string(nil), call.prompts...)}) + } + return calls +} + +type searchTestStorage struct { + mu sync.Mutex + schema *ast.Document + ops map[string][]storage.SessionOp +} + +func newSearchTestStorage(t *testing.T) *searchTestStorage { + t.Helper() + return &searchTestStorage{ + schema: searchHandlerTestSchema(t), + ops: make(map[string][]storage.SessionOp), + } +} + +func (s *searchTestStorage) Append(ctx 
context.Context, sessionID string, ops []storage.SessionOp) ([]storage.SessionOp, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + s.mu.Lock() + defer s.mu.Unlock() + + taken := make(map[string]struct{}, len(s.ops[sessionID])+len(ops)) + for _, op := range s.ops[sessionID] { + taken[op.Name] = struct{}{} + } + + appended := make([]storage.SessionOp, 0, len(ops)) + for _, op := range ops { + op.Name = storage.SuffixedName(storage.NormalizeName(op.Name), taken) + taken[op.Name] = struct{}{} + s.ops[sessionID] = append(s.ops[sessionID], op) + appended = append(appended, op) + } + return appended, nil +} + +func (s *searchTestStorage) GetOp(_ context.Context, sessionID string, name string) (storage.SessionOp, bool, error) { + s.mu.Lock() + defer s.mu.Unlock() + for _, op := range s.ops[sessionID] { + if op.Name == name { + return op, true, nil + } + } + return storage.SessionOp{}, false, nil +} + +func (s *searchTestStorage) ListNames(_ context.Context, sessionID string) ([]string, error) { + s.mu.Lock() + defer s.mu.Unlock() + names := make([]string, 0, len(s.ops[sessionID])) + for _, op := range s.ops[sessionID] { + names = append(names, op.Name) + } + return names, nil +} + +func (s *searchTestStorage) Bundle(context.Context, string) (string, error) { + return "", nil +} + +func (s *searchTestStorage) Reset(_ context.Context, sessionID string) error { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.ops, sessionID) + return nil +} + +func (s *searchTestStorage) SetSchema(schema *ast.Document) { + s.mu.Lock() + defer s.mu.Unlock() + s.schema = schema +} + +func (s *searchTestStorage) Schema() *ast.Document { + s.mu.Lock() + defer s.mu.Unlock() + return s.schema +} + +func (s *searchTestStorage) Start(context.Context) error { + return nil +} + +func (s *searchTestStorage) Stop() error { + return nil +} + +func (s *searchTestStorage) opsSnapshot(sessionID string) []storage.SessionOp { + s.mu.Lock() + defer s.mu.Unlock() + return 
append([]storage.SessionOp(nil), s.ops[sessionID]...) +} + +type legacyCatalogueEntry struct { + Name string `json:"name"` + Body string `json:"body"` + Kind string `json:"kind"` + Description string `json:"description"` + Variables *string `json:"variables"` +} + +func newSearchTestServer(t *testing.T, stateless bool, searcher *fakeYoko, store *searchTestStorage) *Server { + t.Helper() + srv, err := New(Config{ + CodeModeEnabled: true, + NamedOpsEnabled: true, + SessionStateless: stateless, + Storage: store, + YokoClient: searcher, + BundleRenderer: storage.RendererFunc(func([]storage.SessionOp) (string, error) { return "", nil }), + Logger: zap.NewNop(), + }) + require.NoError(t, err) + return srv +} + +func searchToolRequest(t *testing.T, sessionID string, arguments map[string]any) *mcp.CallToolRequest { + t.Helper() + body, err := json.Marshal(arguments) + require.NoError(t, err) + return &mcp.CallToolRequest{ + Params: &mcp.CallToolParamsRaw{ + Name: "code_mode_search_tools", + Arguments: body, + }, + Extra: &mcp.RequestExtra{Header: http.Header{mcpSessionIDHeader: []string{sessionID}}}, + } +} + +func textToolResult(text string) *mcp.CallToolResult { + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: text}}, + } +} + +func ptrString(value string) *string { + return &value +} + +func searchHandlerTestSchema(t *testing.T) *ast.Document { + t.Helper() + doc, report := astparser.ParseGraphqlDocumentString(searchHandlerTestSchemaSDL) + require.False(t, report.HasErrors(), report.Error()) + require.NoError(t, asttransform.MergeDefinitionWithBaseSchema(&doc)) + return &doc +} diff --git a/router/internal/codemode/server/server.go b/router/internal/codemode/server/server.go new file mode 100644 index 0000000000..17fbffa32a --- /dev/null +++ b/router/internal/codemode/server/server.go @@ -0,0 +1,458 @@ +package server + +import ( + "context" + "encoding/json" + "errors" + "net" + "net/http" + "sync" + "time" + + 
"github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/wundergraph/cosmo/router/internal/codemode/calltrace" + "github.com/wundergraph/cosmo/router/internal/codemode/harness" + "github.com/wundergraph/cosmo/router/internal/codemode/observability" + "github.com/wundergraph/cosmo/router/internal/codemode/sandbox" + "github.com/wundergraph/cosmo/router/internal/codemode/storage" + "github.com/wundergraph/cosmo/router/internal/codemode/tsgen" + "github.com/wundergraph/cosmo/router/internal/codemode/yoko" + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + otelmetric "go.opentelemetry.io/otel/metric" + "go.opentelemetry.io/otel/trace" + "go.uber.org/zap" + "golang.org/x/sync/singleflight" +) + +const ( + defaultListenAddr = "localhost:5027" + defaultExecuteTimeout = 120 * time.Second + defaultMaxResultBytes = 32 << 10 + mcpPath = "/mcp" + persistedOpsURI = "yoko://persisted-ops.d.ts" + statelessNamedOpsWarnMessage = "code mode named operations are disabled because MCP session stateless mode is enabled" + namedOpsDisabledMessage = "named operations are disabled" +) + +const searchAPIDescription = "Plan ALL data shapes you need up front, then call ONCE with every prompt in a single batch. Each extra search is a round-trip you pay for.\n\nDEFAULT TO ONE PROMPT. If the entities are related in any way — same domain, joinable, fetched together to answer one question, traversed via the same parent, or the user mentioned them in the same breath — combine them into a SINGLE prompt that describes the complete joined shape. Multiple prompts should be the exception, not the default.\n\nWrite each prompt as the COMPLETE final shape of data you want, including joins and correlation IDs. 
Yoko writes GraphQL across federated subgraphs, so a single prompt like \"employees with id, first name, last name, and their pets (name, type)\" returns one joined operation — never split this into \"list employees\" + \"list pets with owner\" that you'd then have to correlate in JS. If you DO split, every prompt MUST include the join keys (IDs / foreign keys) needed to correlate the results — otherwise the operations come back un-joinable and you'll have to search again.\n\nBE PRECISE about what you need. Vague prompts produce vague operations and force re-searches. Always state:\n- The exact fields you need on each entity (\"id, forename, surname\" — not \"name info\").\n- The relationships to traverse and how deep (\"employees with their pets and each pet's owner's department\").\n- Any required filters/arguments and the values or variable names (\"by id=42\", \"where status=ACTIVE\", \"limit 50\").\n- The shape of nested/related entities, field by field — do not say \"with all their data\".\n- Concrete entity and relationship names from the domain when you know them; otherwise describe the relationship explicitly (\"the team an employee belongs to\").\nA precise prompt: \"employee by id (variable: $id) returning id, forename, surname, role, and pets { name, type, age }\". A vague prompt: \"get employee details with related stuff\" — this will come back missing fields you need.\n\nWhen to use multiple prompts (rare): genuinely unrelated operations on disjoint domains, different argument shapes that can't share a parent, or queries vs mutations. Never slice one joinable shape into fragments. When in doubt, combine.\n\nDo NOT issue prompts for derived/computed values: averages, medians, counts, filters, exclusions (\"without X\"), sorting, top-N. Fetch the raw rows once and compute in code_mode_run_js. Yoko exposes data; arithmetic and reshaping happen in your JS.\n\nAnti-pattern: search → inspect result → notice a field or ID is missing → search again. 
One well-formed prompt beats three round-trips.\n\nThe response appends newly registered TypeScript declarations for use as `await tools.(vars)` inside code_mode_run_js; the cumulative bundle is available at `yoko://persisted-ops.d.ts`." + +const executeAPISourceDescription = "JavaScript source containing a single async arrow function. The host wraps it as `()()` and awaits the resulting Promise; the resolved JSON-serializable value is the tool result." + +const executeAPIDescription = "Run JavaScript source as a single async arrow function in the Code Mode sandbox. Use `await tools.(vars)` for operations registered by code_mode_search_tools; the cumulative tools namespace is available at `yoko://persisted-ops.d.ts`.\n\nStyle: write compact source — single line if it fits, no // comments, no blank lines, short variable names. The JSON wrapping that encodes your source charges you for every newline and indent space.\n\nBatch everything into ONE code_mode_run_js call. ≥3 `tools.*` invocations per call is normal; over-fetch and decide in JS, don't round-trip. A failing inner call degrades the result, not the whole script — wrap with try/catch and surface the error in the return value.\n\nThe return value of your async arrow is the only output channel — `console` is not available. To surface intermediate state, include it in the returned object (e.g. `return { result, debug: { ... } }`). For resilient fan-out use `Promise.allSettled` — `Promise.all` rejects on first failure and discards partial results. Up to 256 `tools.*` invocations per call. 
Return values must be JSON-serializable; `BigInt`, functions, symbols, and circular refs throw `NotSerializable`.\n\nExample: `async()=>{const o=await tools.getOrders({customerId:\"c_1\"});if(o.errors?.length)throw new Error(o.errors[0].message);return o.data.orders;}`\n\nType declarations for reference (consumed via `yoko://persisted-ops.d.ts`):\n\n```ts\ntype GraphQLError = { message: string; path?: (string | number)[]; extensions?: Record };\ntype R = Promise<{ data: T | null; errors?: GraphQLError[] }>;\n\ndeclare const tools: {};\n\ndeclare function notNull(value: T | null | undefined, message?: string): T;\ndeclare function compact(value: T): T;\n```" + +// Config configures the Code Mode MCP server. +type Config struct { + ListenAddr string + CodeModeEnabled bool + NamedOpsEnabled bool + SessionStateless bool + Storage storage.SessionStorage + Pipeline harness.Executor + YokoClient yoko.Searcher + BundleRenderer storage.Renderer + ExecuteTimeout time.Duration + MaxResultBytes int + ApprovalGate sandbox.ApprovalGate + Logger *zap.Logger + MeterProvider otelmetric.MeterProvider + TracerProvider trace.TracerProvider + CallTraceRecorder calltrace.Recorder +} + +// Server owns the Code Mode MCP server and its separate HTTP listener. 
type Server struct {
	// Immutable configuration captured at construction time.
	listenAddr        string
	codeModeEnabled   bool
	namedOpsEnabled   bool
	sessionStateless  bool
	storage           storage.SessionStorage
	pipeline          harness.Executor
	yokoClient        yoko.Searcher
	bundleRenderer    storage.Renderer
	executeTimeout    time.Duration
	maxResultBytes    int
	approvalGate      sandbox.ApprovalGate
	logger            *zap.Logger
	meter             *observability.Meter
	tracerProvider    trace.TracerProvider
	callTraceRecorder calltrace.Recorder

	mcpServer *mcp.Server
	// searchGroup de-duplicates concurrent identical search calls.
	searchGroup singleflight.Group
	// newOpsFragment renders TS declarations for new ops; swappable in tests.
	newOpsFragment func([]storage.SessionOp, *ast.Document) (string, error)

	// mu guards the mutable fields below.
	mu                      sync.Mutex
	httpServer              *http.Server
	actualAddr              string
	warnedStatelessNamedOps bool
	warnedMissingSessionID  bool
}

// New creates a Code Mode MCP server.
func New(cfg Config) (*Server, error) {
	// Apply defaults for anything the caller left zero-valued.
	if cfg.ListenAddr == "" {
		cfg.ListenAddr = defaultListenAddr
	}
	if cfg.Logger == nil {
		cfg.Logger = zap.NewNop()
	}
	if cfg.MeterProvider == nil {
		cfg.MeterProvider = otel.GetMeterProvider()
	}
	if cfg.TracerProvider == nil {
		cfg.TracerProvider = otel.GetTracerProvider()
	}
	if cfg.CallTraceRecorder == nil {
		cfg.CallTraceRecorder = calltrace.NopRecorder{}
	}
	if cfg.ExecuteTimeout <= 0 {
		cfg.ExecuteTimeout = defaultExecuteTimeout
	}
	if cfg.MaxResultBytes <= 0 {
		cfg.MaxResultBytes = defaultMaxResultBytes
	}
	// Propagate the result cap into the concrete pipeline when one is used.
	if pipeline, ok := cfg.Pipeline.(*harness.Pipeline); ok {
		pipeline.MaxResultBytes = cfg.MaxResultBytes
	}
	meter, err := observability.NewMeter(cfg.MeterProvider)
	if err != nil {
		return nil, err
	}

	s := &Server{
		listenAddr:        cfg.ListenAddr,
		codeModeEnabled:   cfg.CodeModeEnabled,
		namedOpsEnabled:   cfg.NamedOpsEnabled,
		sessionStateless:  cfg.SessionStateless,
		storage:           cfg.Storage,
		pipeline:          cfg.Pipeline,
		yokoClient:        cfg.YokoClient,
		bundleRenderer:    cfg.BundleRenderer,
		executeTimeout:    cfg.ExecuteTimeout,
		maxResultBytes:    cfg.MaxResultBytes,
		approvalGate:      cfg.ApprovalGate,
		logger:            cfg.Logger,
		meter:             meter,
		tracerProvider:    cfg.TracerProvider,
		callTraceRecorder: cfg.CallTraceRecorder,
		newOpsFragment:    tsgen.NewOpsFragment,
	}

	s.mcpServer = mcp.NewServer(&mcp.Implementation{
		Name:    "yoko",
		Title:   "Yoko (Cosmo Code Mode)",
		Version: "v0.1.0",
	}, &mcp.ServerOptions{
		HasResources: true,
	})

	if cfg.CodeModeEnabled {
		s.registerTools()
		// The persisted-ops resource only makes sense with per-session state.
		if cfg.NamedOpsEnabled && !cfg.SessionStateless {
			s.registerPersistedOpsResource()
		}
	}

	return s, nil
}

// Start binds the separate Code Mode MCP HTTP listener and serves until the
// server shuts down or ctx is canceled. When Code Mode is disabled it is a no-op.
func (s *Server) Start(ctx context.Context) error {
	if !s.codeModeEnabled {
		return nil
	}

	if s.storage != nil {
		if err := s.storage.Start(ctx); err != nil {
			return err
		}
	}

	listener, err := net.Listen("tcp", s.listenAddr)
	if err != nil {
		// Roll back the storage start so a failed bind leaves nothing running.
		if s.storage != nil {
			_ = s.storage.Stop()
		}
		return err
	}

	httpServer := &http.Server{
		Addr:         s.listenAddr,
		Handler:      s.handler(),
		ReadTimeout:  30 * time.Second,
		WriteTimeout: 30 * time.Second,
		IdleTimeout:  60 * time.Second,
	}

	s.mu.Lock()
	s.httpServer = httpServer
	// Record the resolved address (e.g. ":0" → actual port) for Addr().
	s.actualAddr = listener.Addr().String()
	s.mu.Unlock()

	// Tie the server lifetime to ctx: cancellation triggers a bounded graceful
	// shutdown, while `done` stops this goroutine once Serve returns.
	done := make(chan struct{})
	go func() {
		select {
		case <-ctx.Done():
			shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
			defer cancel()
			_ = s.Stop(shutdownCtx)
		case <-done:
		}
	}()

	s.logger.Info("Code Mode MCP server started", zap.String("listen_addr", listener.Addr().String()), zap.String("path", mcpPath))
	err = httpServer.Serve(listener)
	close(done)
	// Serve reports ErrServerClosed after a graceful Shutdown; treat as success.
	if errors.Is(err, http.ErrServerClosed) {
		return nil
	}
	return err
}

// Stop gracefully shuts down the Code Mode MCP HTTP server. Disabled or unstarted
// servers are no-ops.
func (s *Server) Stop(ctx context.Context) error {
	if !s.codeModeEnabled {
		return nil
	}

	s.mu.Lock()
	httpServer := s.httpServer
	s.mu.Unlock()
	if httpServer == nil {
		// Never started (or already stopped): only storage may need closing.
		if s.storage != nil {
			return s.storage.Stop()
		}
		return nil
	}
	err := httpServer.Shutdown(ctx)
	if err == nil || errors.Is(err, http.ErrServerClosed) {
		s.mu.Lock()
		// Clear the field only if no newer server instance replaced it.
		if s.httpServer == httpServer {
			s.httpServer = nil
		}
		s.mu.Unlock()
		if s.storage != nil {
			return s.storage.Stop()
		}
		return nil
	}
	return err
}

// Reload forwards schema state into Code Mode dependencies. Disabled servers
// ignore reloads.
func (s *Server) Reload(schema *ast.Document, sdl string) error {
	if !s.codeModeEnabled {
		return nil
	}
	if s.storage != nil {
		s.storage.SetSchema(schema)
	}
	if s.yokoClient != nil {
		s.yokoClient.SetSchema(sdl)
	}
	// Surface (once) that named ops cannot work with stateless MCP sessions.
	if s.sessionStateless && s.namedOpsEnabled {
		s.warnStatelessNamedOpsOnce()
	}
	observability.LogSessionLifecycle(s.logger, "schema_swap", "", zap.Int("sdl_bytes", len(sdl)))
	return nil
}

// registerTools registers the two generic Code Mode tools on the MCP server.
func (s *Server) registerTools() {
	s.mcpServer.AddTool(&mcp.Tool{
		Name:        "code_mode_search_tools",
		Description: searchAPIDescription,
		InputSchema: searchAPIInputSchema(),
	}, s.handleSearch)

	s.mcpServer.AddTool(&mcp.Tool{
		Name:        "code_mode_run_js",
		Description: executeAPIDescription,
		InputSchema: executeAPIInputSchema(),
	}, s.handleExecute)
}

// registerPersistedOpsResource exposes the cumulative per-session TypeScript
// bundle as an MCP resource.
func (s *Server) registerPersistedOpsResource() {
	s.mcpServer.AddResource(&mcp.Resource{
		URI:         persistedOpsURI,
		Name:        "persisted-ops.d.ts",
		Title:       "Persisted operations TypeScript definitions",
		Description: "Cumulative TypeScript definitions for the current Code Mode MCP session's named operations.",
		MIMEType:    "text/plain",
	}, s.handlePersistedOpsResource)
}

// handler builds the HTTP mux that serves the streamable MCP endpoint.
func (s *Server) handler() http.Handler {
	// NOTE(review): the bypass pattern disables cross-origin protection for
	// every path — confirm this is intentional for a localhost-only listener.
	cop := http.NewCrossOriginProtection()
	cop.AddInsecureBypassPattern("/{path...}")

	streamableHTTPHandler := mcp.NewStreamableHTTPHandler(
		func(*http.Request) *mcp.Server {
			return s.mcpServer
		},
		&mcp.StreamableHTTPOptions{
			Stateless:             s.sessionStateless,
			CrossOriginProtection: cop,
		},
	)

	mux := http.NewServeMux()
	mux.Handle(mcpPath, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		// Stash the MCP session id in the context before the SDK handler runs.
		req = req.WithContext(withSessionIDFromRequest(req.Context(), req))
		streamableHTTPHandler.ServeHTTP(w, req)
	}))
	return mux
}

// handleSearch dispatches code_mode_search_tools through the shared tool wrapper.
func (s *Server) handleSearch(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) {
	return s.handleTool(ctx, req, "code_mode_search_tools", s.handleSearchAPI)
}

// handleExecute dispatches code_mode_run_js through the shared tool wrapper.
func (s *Server) handleExecute(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) {
	return s.handleTool(ctx, req, "code_mode_run_js", s.handleExecuteAPI)
}

// handleTool wraps a tool handler with tracing, metrics, lifecycle logging and
// optional call-trace recording. The results are named so the deferred block
// can observe the final status and duration.
func (s *Server) handleTool(ctx context.Context, req *mcp.CallToolRequest, toolName string, next func(context.Context, *mcp.CallToolRequest) (*mcp.CallToolResult, error)) (result *mcp.CallToolResult, err error) {
	start := time.Now()
	ctx, span := observability.StartToolSpanWithProvider(ctx, s.tracerProvider, toolName)
	sessionID := sessionIDFromToolRequest(req)
	if calltrace.Enabled(s.callTraceRecorder) {
		s.callTraceRecorder.RecordRequest(toolName, toolRequestBody(req))
	}
	observability.LogSessionLifecycle(s.logger, toolName+".started", sessionID)
	defer func() {
		status := toolStatus(result, err)
		durationMs := float64(time.Since(start)) / float64(time.Millisecond)
		span.SetAttributes(attribute.String("mcp.status", status))
		s.meter.Record(ctx, toolName, status, durationMs)
		observability.LogSessionLifecycle(s.logger, toolName+".completed", sessionID,
			zap.String("status", status),
			zap.Float64("duration_ms", durationMs),
		)
		span.End()
	}()

	result, err = next(ctx, req)
	if calltrace.Enabled(s.callTraceRecorder) {
		if body, marshalErr := json.Marshal(result); marshalErr == nil {
			s.callTraceRecorder.RecordResponse(toolName, body)
		}
	}
	return result,
err +} + +func toolStatus(result *mcp.CallToolResult, err error) string { + if err != nil || (result != nil && result.IsError) { + return "error" + } + return "success" +} + +func sessionIDFromToolRequest(req *mcp.CallToolRequest) string { + if req == nil || req.GetExtra() == nil { + return "" + } + return req.GetExtra().Header.Get(mcpSessionIDHeader) +} + +func toolRequestBody(req *mcp.CallToolRequest) []byte { + if req == nil || req.Params == nil || len(req.Params.Arguments) == 0 { + return []byte(`null`) + } + return append([]byte(nil), req.Params.Arguments...) +} + +func (s *Server) handlePersistedOpsResource(ctx context.Context, req *mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) { + ctx = contextWithSessionFromExtra(ctx, req.GetExtra()) + if s.storage == nil { + return &mcp.ReadResourceResult{ + Contents: []*mcp.ResourceContents{{ + URI: persistedOpsURI, + MIMEType: "text/plain", + Text: "", + }}, + }, nil + } + bundle, err := s.storage.Bundle(ctx, SessionIDFromContext(ctx)) + if err != nil { + return nil, err + } + return &mcp.ReadResourceResult{ + Contents: []*mcp.ResourceContents{{ + URI: persistedOpsURI, + MIMEType: "text/plain", + Text: bundle, + }}, + }, nil +} + +func contextWithSessionFromExtra(ctx context.Context, extra *mcp.RequestExtra) context.Context { + if extra == nil { + return WithSessionID(ctx, "") + } + return WithSessionID(ctx, extra.Header.Get(mcpSessionIDHeader)) +} + +func toolErrorResult(message string) *mcp.CallToolResult { + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: message}}, + IsError: true, + } +} + +func searchAPIInputSchema() map[string]any { + return map[string]any{ + "type": "object", + "required": []any{"prompts"}, + "properties": map[string]any{ + "prompts": map[string]any{ + "type": "array", + "minItems": 1, + "maxItems": 20, + "items": map[string]any{ + "type": "string", + "minLength": 1, + }, + }, + }, + } +} + +func executeAPIInputSchema() map[string]any { + return 
map[string]any{ + "type": "object", + "required": []any{"source"}, + "properties": map[string]any{ + "source": map[string]any{ + "type": "string", + "minLength": 1, + "description": executeAPISourceDescription, + }, + }, + } +} + +func (s *Server) warnStatelessNamedOpsOnce() { + s.mu.Lock() + defer s.mu.Unlock() + if s.warnedStatelessNamedOps { + return + } + s.warnedStatelessNamedOps = true + s.logger.Warn(statelessNamedOpsWarnMessage) +} + +// Addr returns the listener address once Start has bound it. +func (s *Server) Addr() string { + return s.addr() +} + +func (s *Server) addr() string { + s.mu.Lock() + defer s.mu.Unlock() + return s.actualAddr +} diff --git a/router/internal/codemode/server/server_test.go b/router/internal/codemode/server/server_test.go new file mode 100644 index 0000000000..65f153dfa7 --- /dev/null +++ b/router/internal/codemode/server/server_test.go @@ -0,0 +1,481 @@ +package server + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "slices" + "sync" + "syscall" + "testing" + "time" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router/internal/codemode/storage" + "github.com/wundergraph/cosmo/router/internal/codemode/yoko" + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" + "go.uber.org/zap" + "go.uber.org/zap/zaptest/observer" +) + +func TestStartDisabledReturnsWithoutBinding(t *testing.T) { + srv, err := New(Config{ + ListenAddr: "127.0.0.1:0", + CodeModeEnabled: false, + Storage: newRecordingStorage(), + YokoClient: yoko.New(nil, "http://127.0.0.1", zap.NewNop()), + BundleRenderer: storage.RendererFunc(func([]storage.SessionOp) (string, error) { return "", nil }), + Logger: zap.NewNop(), + NamedOpsEnabled: true, + SessionStateless: false, + }) + require.NoError(t, err) + + err = srv.Start(context.Background()) + + require.NoError(t, err) + assert.Equal(t, "", srv.addr()) + require.NoError(t, 
srv.Stop(context.Background())) +} + +func TestListToolsReturnsCodeModeTools(t *testing.T) { + srv := newTestServer(t, Config{ + CodeModeEnabled: true, + NamedOpsEnabled: true, + SessionStateless: false, + }) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + startServer(t, ctx, srv) + defer stopServer(t, srv) + + session := connectHTTPClient(t, ctx, "http://"+srv.addr()+"/mcp") + defer session.Close() + + got, err := session.ListTools(ctx, &mcp.ListToolsParams{}) + require.NoError(t, err) + require.Len(t, got.Tools, 2) + slices.SortFunc(got.Tools, func(a, b *mcp.Tool) int { + if a.Name < b.Name { + return -1 + } + if a.Name > b.Name { + return 1 + } + return 0 + }) + + assert.Equal(t, mustJSON(t, []*mcp.Tool{ + { + Name: "code_mode_run_js", + Description: executeAPIDescription, + InputSchema: map[string]any{ + "type": "object", + "required": []any{"source"}, + "properties": map[string]any{ + "source": map[string]any{ + "type": "string", + "minLength": float64(1), + "description": executeAPISourceDescription, + }, + }, + }, + }, + { + Name: "code_mode_search_tools", + Description: searchAPIDescription, + InputSchema: map[string]any{ + "type": "object", + "required": []any{"prompts"}, + "properties": map[string]any{ + "prompts": map[string]any{ + "type": "array", + "minItems": float64(1), + "maxItems": float64(20), + "items": map[string]any{ + "type": "string", + "minLength": float64(1), + }, + }, + }, + }, + }, + }), mustJSON(t, got.Tools)) +} + +func TestListResourcesGating(t *testing.T) { + tests := []struct { + name string + cfg Config + wantPresent bool + }{ + { + name: "code mode disabled", + cfg: Config{ + CodeModeEnabled: false, + NamedOpsEnabled: true, + SessionStateless: false, + }, + }, + { + name: "named ops disabled", + cfg: Config{ + CodeModeEnabled: true, + NamedOpsEnabled: false, + SessionStateless: false, + }, + }, + { + name: "stateless disables named ops", + cfg: Config{ + CodeModeEnabled: true, + 
NamedOpsEnabled: true, + SessionStateless: true, + }, + }, + { + name: "stateful named ops", + cfg: Config{ + CodeModeEnabled: true, + NamedOpsEnabled: true, + SessionStateless: false, + }, + wantPresent: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + srv := newTestServer(t, tt.cfg) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + session := connectInMemoryClient(t, ctx, srv) + defer session.Close() + + got, err := session.ListResources(ctx, &mcp.ListResourcesParams{}) + require.NoError(t, err) + assert.Equal(t, tt.wantPresent, hasResource(got.Resources, persistedOpsURI)) + }) + } +} + +func TestStatelessNamedOpsReloadWarnsOnce(t *testing.T) { + core, recorded := observer.New(zap.WarnLevel) + store := newRecordingStorage() + client := yoko.New(nil, "http://127.0.0.1", zap.NewNop()) + srv, err := New(Config{ + ListenAddr: "127.0.0.1:0", + CodeModeEnabled: true, + NamedOpsEnabled: true, + SessionStateless: true, + Storage: store, + YokoClient: client, + BundleRenderer: storage.RendererFunc(func([]storage.SessionOp) (string, error) { return "", nil }), + Logger: zap.New(core), + }) + require.NoError(t, err) + + require.NoError(t, srv.Reload(&ast.Document{}, "schema { query: Query }")) + require.NoError(t, srv.Reload(&ast.Document{}, "schema { query: Query }")) + + assert.Equal(t, 1, recorded.FilterMessage(statelessNamedOpsWarnMessage).Len()) +} + +func TestExecuteToolStubReturnsDeterministicToolError(t *testing.T) { + srv := newTestServer(t, Config{ + CodeModeEnabled: true, + NamedOpsEnabled: false, + SessionStateless: false, + }) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + startServer(t, ctx, srv) + defer stopServer(t, srv) + + session := connectHTTPClient(t, ctx, "http://"+srv.addr()+"/mcp") + defer session.Close() + + executeResult, err := session.CallTool(ctx, &mcp.CallToolParams{ + Name: "code_mode_run_js", + Arguments: 
map[string]any{"source": "async () => null"}, + }) + require.NoError(t, err) + assert.Equal(t, mustJSON(t, toolError("named operations are disabled")), mustJSON(t, executeResult)) +} + +func TestSessionIDExtraction(t *testing.T) { + req, err := http.NewRequest(http.MethodPost, "http://example.com/mcp", nil) + require.NoError(t, err) + req.Header.Set("Mcp-Session-Id", "session-123") + + ctx := withSessionIDFromRequest(context.Background(), req) + + assert.Equal(t, "session-123", SessionIDFromContext(ctx)) + assert.Equal(t, "", SessionIDFromContext(context.Background())) + assert.Equal(t, "manual", SessionIDFromContext(WithSessionID(context.Background(), "manual"))) +} + +func TestResourceHandlerUsesCurrentSessionID(t *testing.T) { + store := newRecordingStorage() + store.bundle = "declare const tools: { getUser(): R<{ id: string }> };" + srv := newTestServer(t, Config{ + CodeModeEnabled: true, + NamedOpsEnabled: true, + SessionStateless: false, + Storage: store, + }) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + startServer(t, ctx, srv) + defer stopServer(t, srv) + + session := connectHTTPClient(t, ctx, "http://"+srv.addr()+"/mcp") + defer session.Close() + + got, err := session.ReadResource(ctx, &mcp.ReadResourceParams{URI: persistedOpsURI}) + require.NoError(t, err) + + require.NotEmpty(t, session.ID()) + assert.Equal(t, session.ID(), store.lastBundleSessionID()) + assert.Equal(t, mustJSON(t, &mcp.ReadResourceResult{ + Contents: []*mcp.ResourceContents{{ + URI: persistedOpsURI, + MIMEType: "text/plain", + Text: store.bundle, + }}, + }), mustJSON(t, got)) +} + +func TestReloadForwardsSchemaAndSDL(t *testing.T) { + store := newRecordingStorage() + client := yoko.New(nil, "http://127.0.0.1", zap.NewNop()) + srv, err := New(Config{ + ListenAddr: "127.0.0.1:0", + CodeModeEnabled: true, + NamedOpsEnabled: false, + SessionStateless: false, + Storage: store, + YokoClient: client, + BundleRenderer: 
storage.RendererFunc(func([]storage.SessionOp) (string, error) { return "", nil }), + Logger: zap.NewNop(), + }) + require.NoError(t, err) + + schema := &ast.Document{} + require.NoError(t, srv.Reload(schema, "schema { query: Query }")) + + assert.Equal(t, schema, store.schema) + assert.Equal(t, 1, store.setSchemaCalls) + assert.Equal(t, "schema { query: Query }", client.Schema()) +} + +func TestReloadDisabledIsNoOp(t *testing.T) { + store := newRecordingStorage() + client := yoko.New(nil, "http://127.0.0.1", zap.NewNop()) + srv, err := New(Config{ + ListenAddr: "127.0.0.1:0", + CodeModeEnabled: false, + NamedOpsEnabled: true, + SessionStateless: false, + Storage: store, + YokoClient: client, + BundleRenderer: storage.RendererFunc(func([]storage.SessionOp) (string, error) { return "", nil }), + Logger: zap.NewNop(), + }) + require.NoError(t, err) + + require.NoError(t, srv.Reload(&ast.Document{}, "schema { query: Query }")) + + assert.Equal(t, 0, store.setSchemaCalls) + assert.Equal(t, "", client.Schema()) +} + +func newTestServer(t *testing.T, cfg Config) *Server { + t.Helper() + if cfg.ListenAddr == "" { + cfg.ListenAddr = "127.0.0.1:0" + } + if cfg.Storage == nil { + cfg.Storage = newRecordingStorage() + } + if cfg.YokoClient == nil { + cfg.YokoClient = yoko.New(nil, "http://127.0.0.1", zap.NewNop()) + } + if cfg.BundleRenderer == nil { + cfg.BundleRenderer = storage.RendererFunc(func([]storage.SessionOp) (string, error) { return "", nil }) + } + if cfg.Logger == nil { + cfg.Logger = zap.NewNop() + } + srv, err := New(cfg) + require.NoError(t, err) + return srv +} + +func startServer(t *testing.T, ctx context.Context, srv *Server) { + t.Helper() + errs := make(chan error, 1) + go func() { + errs <- srv.Start(ctx) + }() + deadline := time.After(5 * time.Second) + tick := time.NewTicker(10 * time.Millisecond) + defer tick.Stop() + bound := false + for { + select { + case err := <-errs: + if isBindPermissionError(err) { + t.Skipf("local listener bind is not 
permitted in this environment: %v", err) + } + require.NoError(t, err) + case <-deadline: + require.FailNow(t, "server listener was not bound") + case <-tick.C: + if srv.addr() != "" { + bound = true + } + } + if bound { + break + } + } + t.Cleanup(func() { + select { + case err := <-errs: + require.NoError(t, err) + case <-time.After(5 * time.Second): + require.FailNow(t, "server did not stop") + } + }) +} + +func isBindPermissionError(err error) bool { + return errors.Is(err, syscall.EACCES) || errors.Is(err, syscall.EPERM) +} + +func stopServer(t *testing.T, srv *Server) { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + require.NoError(t, srv.Stop(ctx)) +} + +func connectHTTPClient(t *testing.T, ctx context.Context, endpoint string) *mcp.ClientSession { + t.Helper() + client := mcp.NewClient(&mcp.Implementation{Name: "code-mode-test-client", Version: "test"}, nil) + session, err := client.Connect(ctx, &mcp.StreamableClientTransport{ + Endpoint: endpoint, + DisableStandaloneSSE: true, + }, nil) + require.NoError(t, err) + return session +} + +func connectInMemoryClient(t *testing.T, ctx context.Context, srv *Server) *mcp.ClientSession { + t.Helper() + clientTransport, serverTransport := mcp.NewInMemoryTransports() + errs := make(chan error, 1) + go func() { + errs <- srv.mcpServer.Run(ctx, serverTransport) + }() + t.Cleanup(func() { + select { + case err := <-errs: + if err != nil && !errors.Is(err, context.Canceled) { + require.NoError(t, err) + } + default: + } + }) + + client := mcp.NewClient(&mcp.Implementation{Name: "code-mode-test-client", Version: "test"}, nil) + session, err := client.Connect(ctx, clientTransport, nil) + require.NoError(t, err) + return session +} + +func hasResource(resources []*mcp.Resource, uri string) bool { + return slices.ContainsFunc(resources, func(resource *mcp.Resource) bool { + return resource.URI == uri + }) +} + +func mustJSON(t *testing.T, value any) string { + 
t.Helper() + data, err := json.Marshal(value) + require.NoError(t, err) + return string(data) +} + +func toolError(message string) *mcp.CallToolResult { + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: message}}, + IsError: true, + } +} + +type recordingStorage struct { + mu sync.Mutex + schema *ast.Document + setSchemaCalls int + bundle string + bundleSessionID string +} + +func newRecordingStorage() *recordingStorage { + return &recordingStorage{bundle: "declare const tools: {};"} +} + +func (s *recordingStorage) Append(_ context.Context, _ string, ops []storage.SessionOp) ([]storage.SessionOp, error) { + return ops, nil +} + +func (s *recordingStorage) GetOp(context.Context, string, string) (storage.SessionOp, bool, error) { + return storage.SessionOp{}, false, nil +} + +func (s *recordingStorage) ListNames(context.Context, string) ([]string, error) { + return nil, nil +} + +func (s *recordingStorage) Bundle(_ context.Context, sessionID string) (string, error) { + s.mu.Lock() + defer s.mu.Unlock() + s.bundleSessionID = sessionID + return s.bundle, nil +} + +func (s *recordingStorage) Reset(context.Context, string) error { + return nil +} + +func (s *recordingStorage) SetSchema(schema *ast.Document) { + s.mu.Lock() + defer s.mu.Unlock() + s.schema = schema + s.setSchemaCalls++ +} + +func (s *recordingStorage) Schema() *ast.Document { + s.mu.Lock() + defer s.mu.Unlock() + return s.schema +} + +func (s *recordingStorage) Start(context.Context) error { + return nil +} + +func (s *recordingStorage) Stop() error { + return nil +} + +func (s *recordingStorage) lastBundleSessionID() string { + s.mu.Lock() + defer s.mu.Unlock() + return s.bundleSessionID +} diff --git a/router/internal/codemode/server/session.go b/router/internal/codemode/server/session.go new file mode 100644 index 0000000000..ed2dcaeba5 --- /dev/null +++ b/router/internal/codemode/server/session.go @@ -0,0 +1,34 @@ +package server + +import ( + "context" + "net/http" +) + 
+const mcpSessionIDHeader = "Mcp-Session-Id" + +type sessionIDContextKey struct{} + +// SessionIDFromContext returns the MCP Streamable-HTTP session ID stored on ctx. +// An empty value is meaningful: it indicates stateless mode or a request without +// Mcp-Session-Id, and callers must not synthesize a replacement. +func SessionIDFromContext(ctx context.Context) string { + id, _ := ctx.Value(sessionIDContextKey{}).(string) + return id +} + +// WithSessionID stores id on ctx for Code Mode handlers. +func WithSessionID(ctx context.Context, id string) context.Context { + return context.WithValue(ctx, sessionIDContextKey{}, id) +} + +// withSessionIDFromRequest reads Mcp-Session-Id directly from the HTTP request. +// The modelcontextprotocol/go-sdk exposes transport headers to MCP handlers as +// req.Extra.Header; handlers call WithSessionID(ctx, req.Extra.Header.Get(...)). +// This helper is used for HTTP middleware/tests where the raw request is known. +func withSessionIDFromRequest(ctx context.Context, req *http.Request) context.Context { + if req == nil { + return WithSessionID(ctx, "") + } + return WithSessionID(ctx, req.Header.Get(mcpSessionIDHeader)) +} diff --git a/router/internal/codemode/storage/memory_backend.go b/router/internal/codemode/storage/memory_backend.go new file mode 100644 index 0000000000..7467d17635 --- /dev/null +++ b/router/internal/codemode/storage/memory_backend.go @@ -0,0 +1,354 @@ +package storage + +import ( + "context" + "errors" + "sort" + "sync" + "sync/atomic" + "time" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" +) + +const ( + defaultSessionTTL = 30 * time.Minute + defaultMaxSessions = 1_000 + defaultMaxBundleBytes = 1 << 20 +) + +type MemoryConfig struct { + SessionTTL time.Duration + MaxSessions int + MaxBundleBytes int + Renderer Renderer + Now func() time.Time +} + +type MemoryBackend struct { + sessionTTL time.Duration + maxSessions int + maxBundleBytes int + renderer Renderer + now func() time.Time + + sessions 
sync.Map + + schemaMu sync.RWMutex + schema *ast.Document + + schemaVer atomic.Uint64 + + lifecycleMu sync.Mutex + cancel context.CancelFunc + done chan struct{} +} + +type memSession struct { + mu sync.Mutex + ops []SessionOp + lastUsed time.Time + bundle string + bundleValid bool +} + +type sessionSnapshot struct { + id string + lastUsed time.Time +} + +func NewMemoryBackend(config MemoryConfig) *MemoryBackend { + if config.SessionTTL <= 0 { + config.SessionTTL = defaultSessionTTL + } + if config.MaxSessions <= 0 { + config.MaxSessions = defaultMaxSessions + } + if config.MaxBundleBytes < 0 { + config.MaxBundleBytes = 0 + } + if config.MaxBundleBytes == 0 { + config.MaxBundleBytes = defaultMaxBundleBytes + } + if config.Now == nil { + config.Now = time.Now + } + + return &MemoryBackend{ + sessionTTL: config.SessionTTL, + maxSessions: config.MaxSessions, + maxBundleBytes: config.MaxBundleBytes, + renderer: config.Renderer, + now: config.Now, + } +} + +func (b *MemoryBackend) Append(ctx context.Context, sessionID string, ops []SessionOp) ([]SessionOp, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + if len(ops) == 0 { + return nil, nil + } + + session := b.loadOrCreateSession(sessionID) + session.mu.Lock() + appended := make([]SessionOp, 0, len(ops)) + taken := make(map[string]struct{}, len(session.ops)+len(ops)) + for _, op := range session.ops { + taken[op.Name] = struct{}{} + } + for _, op := range ops { + op.Name = SuffixedName(NormalizeName(op.Name), taken) + taken[op.Name] = struct{}{} + session.ops = append(session.ops, op) + appended = append(appended, op) + } + session.lastUsed = b.now() + session.bundle = "" + session.bundleValid = false + session.mu.Unlock() + + b.enforceMaxSessions() + return appended, nil +} + +func (b *MemoryBackend) GetOp(ctx context.Context, sessionID string, name string) (SessionOp, bool, error) { + if err := ctx.Err(); err != nil { + return SessionOp{}, false, err + } + + value, ok := b.sessions.Load(sessionID) 
+ if !ok { + return SessionOp{}, false, nil + } + session := value.(*memSession) + session.mu.Lock() + defer session.mu.Unlock() + + session.lastUsed = b.now() + for _, op := range session.ops { + if op.Name == name { + return op, true, nil + } + } + return SessionOp{}, false, nil +} + +func (b *MemoryBackend) ListNames(ctx context.Context, sessionID string) ([]string, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + + value, ok := b.sessions.Load(sessionID) + if !ok { + return nil, nil + } + session := value.(*memSession) + session.mu.Lock() + defer session.mu.Unlock() + + session.lastUsed = b.now() + names := make([]string, 0, len(session.ops)) + for _, op := range session.ops { + names = append(names, op.Name) + } + return names, nil +} + +func (b *MemoryBackend) Bundle(ctx context.Context, sessionID string) (string, error) { + if err := ctx.Err(); err != nil { + return "", err + } + + value, ok := b.sessions.Load(sessionID) + if !ok { + return b.renderCapped(ctx, nil) + } + session := value.(*memSession) + + session.mu.Lock() + defer session.mu.Unlock() + + session.lastUsed = b.now() + if session.bundleValid { + return session.bundle, nil + } + + if b.renderer == nil { + return "", errors.New("code mode storage renderer is not configured") + } + + ops := append([]SessionOp(nil), session.ops...) 
+ bundle, err := b.renderCapped(ctx, ops) + if err != nil { + return "", err + } + + session.bundle = bundle + session.bundleValid = true + return bundle, nil +} + +func (b *MemoryBackend) Reset(ctx context.Context, sessionID string) error { + if err := ctx.Err(); err != nil { + return err + } + b.sessions.Delete(sessionID) + return nil +} + +func (b *MemoryBackend) SetSchema(schema *ast.Document) { + b.schemaMu.Lock() + b.schema = schema + b.schemaMu.Unlock() + + b.schemaVer.Add(1) + b.clearSessions() +} + +func (b *MemoryBackend) Schema() *ast.Document { + b.schemaMu.RLock() + defer b.schemaMu.RUnlock() + return b.schema +} + +func (b *MemoryBackend) SchemaVersion() uint64 { + return b.schemaVer.Load() +} + +func (b *MemoryBackend) Start(ctx context.Context) error { + b.lifecycleMu.Lock() + defer b.lifecycleMu.Unlock() + + if b.cancel != nil { + return nil + } + + runCtx, cancel := context.WithCancel(ctx) + b.cancel = cancel + b.done = make(chan struct{}) + go b.runSweeper(runCtx, b.done) + return nil +} + +func (b *MemoryBackend) Stop() error { + b.lifecycleMu.Lock() + cancel := b.cancel + done := b.done + b.cancel = nil + b.done = nil + b.lifecycleMu.Unlock() + + if cancel == nil { + return nil + } + cancel() + <-done + return nil +} + +func (b *MemoryBackend) loadOrCreateSession(sessionID string) *memSession { + now := b.now() + session := &memSession{lastUsed: now} + value, _ := b.sessions.LoadOrStore(sessionID, session) + return value.(*memSession) +} + +func (b *MemoryBackend) renderCapped(ctx context.Context, ops []SessionOp) (string, error) { + bundle, err := b.renderer.Render(ctx, ops, b.Schema()) + if err != nil { + return "", err + } + if b.maxBundleBytes <= 0 || len(bundle) <= b.maxBundleBytes { + return bundle, nil + } + + for keep := len(ops) - 1; keep >= 0; keep-- { + if err := ctx.Err(); err != nil { + return "", err + } + truncated, err := b.renderer.Render(ctx, ops[:keep], b.Schema()) + if err != nil { + return "", err + } + if len(truncated) <= 
b.maxBundleBytes { + return truncated, nil + } + } + return "", nil +} + +func (b *MemoryBackend) runSweeper(ctx context.Context, done chan<- struct{}) { + defer close(done) + + interval := b.sessionTTL / 4 + if interval <= 0 { + interval = time.Second + } + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + b.sweepIdle() + b.enforceMaxSessions() + } + } +} + +func (b *MemoryBackend) sweepIdle() { + if b.sessionTTL <= 0 { + return + } + + cutoff := b.now().Add(-b.sessionTTL) + b.sessions.Range(func(key, value any) bool { + session := value.(*memSession) + session.mu.Lock() + expired := !session.lastUsed.After(cutoff) + session.mu.Unlock() + if expired { + b.sessions.Delete(key) + } + return true + }) +} + +func (b *MemoryBackend) enforceMaxSessions() { + if b.maxSessions <= 0 { + return + } + + snapshots := make([]sessionSnapshot, 0) + b.sessions.Range(func(key, value any) bool { + session := value.(*memSession) + session.mu.Lock() + snapshots = append(snapshots, sessionSnapshot{id: key.(string), lastUsed: session.lastUsed}) + session.mu.Unlock() + return true + }) + if len(snapshots) <= b.maxSessions { + return + } + + sort.Slice(snapshots, func(i, j int) bool { + if snapshots[i].lastUsed.Equal(snapshots[j].lastUsed) { + return snapshots[i].id < snapshots[j].id + } + return snapshots[i].lastUsed.Before(snapshots[j].lastUsed) + }) + for _, snapshot := range snapshots[:len(snapshots)-b.maxSessions] { + b.sessions.Delete(snapshot.id) + } +} + +func (b *MemoryBackend) clearSessions() { + b.sessions.Range(func(key, _ any) bool { + b.sessions.Delete(key) + return true + }) +} diff --git a/router/internal/codemode/storage/memory_backend_test.go b/router/internal/codemode/storage/memory_backend_test.go new file mode 100644 index 0000000000..662816ce64 --- /dev/null +++ b/router/internal/codemode/storage/memory_backend_test.go @@ -0,0 +1,332 @@ +package storage + +import ( + "context" + "fmt" + 
"sort" + "strings" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" +) + +type testClock struct { + mu sync.Mutex + now time.Time +} + +func newTestClock() *testClock { + return &testClock{now: time.Unix(1_700_000_000, 0).UTC()} +} + +func (c *testClock) Now() time.Time { + c.mu.Lock() + defer c.mu.Unlock() + return c.now +} + +func (c *testClock) Advance(d time.Duration) { + c.mu.Lock() + defer c.mu.Unlock() + c.now = c.now.Add(d) +} + +func newTestBackend(t *testing.T, clock *testClock, renderer Renderer) *MemoryBackend { + t.Helper() + + if renderer == nil { + renderer = RendererFunc(func(ops []SessionOp) (string, error) { + names := make([]string, 0, len(ops)) + for _, op := range ops { + names = append(names, op.Name) + } + return strings.Join(names, "\n"), nil + }) + } + + backend := NewMemoryBackend(MemoryConfig{ + SessionTTL: time.Hour, + MaxSessions: 100, + MaxBundleBytes: 1 << 20, + Renderer: renderer, + Now: clock.Now, + }) + require.NoError(t, backend.Start(context.Background())) + t.Cleanup(func() { + require.NoError(t, backend.Stop()) + }) + + return backend +} + +func TestMemoryBackendAppendGetOpBundleResetRoundTrip(t *testing.T) { + ctx := context.Background() + clock := newTestClock() + backend := newTestBackend(t, clock, nil) + + ops := []SessionOp{ + {Name: "get-user", Body: "query GetUser { user { id } }", Kind: OperationKindQuery, Description: "Fetch a user"}, + {Name: "delete", Body: "mutation DeleteUser { deleteUser(id: 1) }", Kind: OperationKindMutation, Description: "Delete a user"}, + {Name: "get-user", Body: "query GetUserAgain { user { name } }", Kind: OperationKindQuery, Description: "Fetch user name"}, + } + + appended, err := backend.Append(ctx, "session-1", ops) + require.NoError(t, err) + assert.Equal(t, []SessionOp{ + {Name: "getUser", Body: "query GetUser { user { id } }", Kind: OperationKindQuery, Description: 
"Fetch a user"}, + {Name: "op_delete", Body: "mutation DeleteUser { deleteUser(id: 1) }", Kind: OperationKindMutation, Description: "Delete a user"}, + {Name: "getUser_2", Body: "query GetUserAgain { user { name } }", Kind: OperationKindQuery, Description: "Fetch user name"}, + }, appended) + + gotQuery, ok, err := backend.GetOp(ctx, "session-1", "getUser") + require.NoError(t, err) + assert.Equal(t, true, ok) + assert.Equal(t, SessionOp{Name: "getUser", Body: "query GetUser { user { id } }", Kind: OperationKindQuery, Description: "Fetch a user"}, gotQuery) + + gotMutation, ok, err := backend.GetOp(ctx, "session-1", "op_delete") + require.NoError(t, err) + assert.Equal(t, true, ok) + assert.Equal(t, SessionOp{Name: "op_delete", Body: "mutation DeleteUser { deleteUser(id: 1) }", Kind: OperationKindMutation, Description: "Delete a user"}, gotMutation) + + gotCollision, ok, err := backend.GetOp(ctx, "session-1", "getUser_2") + require.NoError(t, err) + assert.Equal(t, true, ok) + assert.Equal(t, SessionOp{Name: "getUser_2", Body: "query GetUserAgain { user { name } }", Kind: OperationKindQuery, Description: "Fetch user name"}, gotCollision) + + bundle, err := backend.Bundle(ctx, "session-1") + require.NoError(t, err) + assert.Equal(t, "getUser\nop_delete\ngetUser_2", bundle) + + require.NoError(t, backend.Reset(ctx, "session-1")) + gotAfterReset, ok, err := backend.GetOp(ctx, "session-1", "getUser") + require.NoError(t, err) + assert.Equal(t, false, ok) + assert.Equal(t, SessionOp{}, gotAfterReset) + + bundleAfterReset, err := backend.Bundle(ctx, "session-1") + require.NoError(t, err) + assert.Equal(t, "", bundleAfterReset) +} + +func TestMemoryBackendSetSchemaClearsSessionsAndIncrementsSchemaVersion(t *testing.T) { + ctx := context.Background() + clock := newTestClock() + backend := newTestBackend(t, clock, nil) + initialVersion := backend.SchemaVersion() + + _, err := backend.Append(ctx, "session-1", []SessionOp{{Name: "get-user", Body: "query { user { id } }", 
Kind: OperationKindQuery}}) + require.NoError(t, err) + schema := &ast.Document{} + + backend.SetSchema(schema) + + assert.Equal(t, initialVersion+1, backend.SchemaVersion()) + assert.Equal(t, schema, backend.Schema()) + + got, ok, err := backend.GetOp(ctx, "session-1", "getUser") + require.NoError(t, err) + assert.Equal(t, false, ok) + assert.Equal(t, SessionOp{}, got) + + backend.SetSchema(&ast.Document{}) + assert.Equal(t, initialVersion+2, backend.SchemaVersion()) +} + +func TestMemoryBackendTTLEvictionUsesInjectedClock(t *testing.T) { + ctx := context.Background() + clock := newTestClock() + backend := NewMemoryBackend(MemoryConfig{ + SessionTTL: time.Minute, + MaxSessions: 100, + MaxBundleBytes: 1 << 20, + Renderer: RendererFunc(func(ops []SessionOp) (string, error) { return "", nil }), + Now: clock.Now, + }) + + _, err := backend.Append(ctx, "idle", []SessionOp{{Name: "idle-op", Body: "query { idle }", Kind: OperationKindQuery}}) + require.NoError(t, err) + _, err = backend.Append(ctx, "fresh", []SessionOp{{Name: "fresh-op", Body: "query { fresh }", Kind: OperationKindQuery}}) + require.NoError(t, err) + clock.Advance(30 * time.Second) + _, ok, err := backend.GetOp(ctx, "fresh", "freshOp") + require.NoError(t, err) + assert.Equal(t, true, ok) + + clock.Advance(31 * time.Second) + backend.sweepIdle() + + _, idleOK, err := backend.GetOp(ctx, "idle", "idleOp") + require.NoError(t, err) + assert.Equal(t, false, idleOK) + + _, freshOK, err := backend.GetOp(ctx, "fresh", "freshOp") + require.NoError(t, err) + assert.Equal(t, true, freshOK) +} + +func TestMemoryBackendLRUEvictionAtMaxSessions(t *testing.T) { + ctx := context.Background() + clock := newTestClock() + backend := NewMemoryBackend(MemoryConfig{ + SessionTTL: time.Hour, + MaxSessions: 2, + MaxBundleBytes: 1 << 20, + Renderer: RendererFunc(func(ops []SessionOp) (string, error) { return "", nil }), + Now: clock.Now, + }) + + _, err := backend.Append(ctx, "session-a", []SessionOp{{Name: "a-op", Body: "query 
{ a }", Kind: OperationKindQuery}}) + require.NoError(t, err) + clock.Advance(time.Second) + _, err = backend.Append(ctx, "session-b", []SessionOp{{Name: "b-op", Body: "query { b }", Kind: OperationKindQuery}}) + require.NoError(t, err) + clock.Advance(time.Second) + _, ok, err := backend.GetOp(ctx, "session-a", "aOp") + require.NoError(t, err) + assert.Equal(t, true, ok) + clock.Advance(time.Second) + + _, err = backend.Append(ctx, "session-c", []SessionOp{{Name: "c-op", Body: "query { c }", Kind: OperationKindQuery}}) + require.NoError(t, err) + + _, aOK, err := backend.GetOp(ctx, "session-a", "aOp") + require.NoError(t, err) + assert.Equal(t, true, aOK) + + _, bOK, err := backend.GetOp(ctx, "session-b", "bOp") + require.NoError(t, err) + assert.Equal(t, false, bOK) + + _, cOK, err := backend.GetOp(ctx, "session-c", "cOp") + require.NoError(t, err) + assert.Equal(t, true, cOK) +} + +func TestMemoryBackendConcurrentAppendIsRaceFreeAndSuffixesNames(t *testing.T) { + ctx := context.Background() + clock := newTestClock() + backend := newTestBackend(t, clock, nil) + + const goroutines = 32 + var wg sync.WaitGroup + errs := make(chan error, goroutines) + + for i := range goroutines { + wg.Add(1) + go func(i int) { + defer wg.Done() + _, err := backend.Append(ctx, "shared", []SessionOp{{ + Name: "shared-op", + Body: fmt.Sprintf("query Shared%d { shared%d }", i, i), + Kind: OperationKindQuery, + Description: fmt.Sprintf("Shared %d", i), + }}) + errs <- err + }(i) + } + + wg.Wait() + close(errs) + + for err := range errs { + require.NoError(t, err) + } + + names := make([]string, 0, goroutines) + for i := range goroutines { + name := "sharedOp" + if i > 0 { + name = fmt.Sprintf("sharedOp_%d", i+1) + } + op, ok, err := backend.GetOp(ctx, "shared", name) + require.NoError(t, err) + assert.Equal(t, true, ok) + names = append(names, op.Name) + } + + sort.Strings(names) + want := make([]string, 0, goroutines) + for i := range goroutines { + name := "sharedOp" + if i > 0 { + 
name = fmt.Sprintf("sharedOp_%d", i+1) + } + want = append(want, name) + } + sort.Strings(want) + assert.Equal(t, want, names) +} + +func TestMemoryBackendBundleCacheInvalidatesOnAppend(t *testing.T) { + ctx := context.Background() + clock := newTestClock() + var mu sync.Mutex + rendered := make([]string, 0, 3) + renderer := RendererFunc(func(ops []SessionOp) (string, error) { + names := make([]string, 0, len(ops)) + for _, op := range ops { + names = append(names, op.Name) + } + bundle := strings.Join(names, ",") + mu.Lock() + rendered = append(rendered, bundle) + mu.Unlock() + return bundle, nil + }) + backend := newTestBackend(t, clock, renderer) + + _, err := backend.Append(ctx, "session-1", []SessionOp{{Name: "one", Body: "query { one }", Kind: OperationKindQuery}}) + require.NoError(t, err) + first, err := backend.Bundle(ctx, "session-1") + require.NoError(t, err) + assert.Equal(t, "one", first) + + second, err := backend.Bundle(ctx, "session-1") + require.NoError(t, err) + assert.Equal(t, "one", second) + + _, err = backend.Append(ctx, "session-1", []SessionOp{{Name: "two", Body: "query { two }", Kind: OperationKindQuery}}) + require.NoError(t, err) + third, err := backend.Bundle(ctx, "session-1") + require.NoError(t, err) + assert.Equal(t, "one,two", third) + + mu.Lock() + gotRendered := append([]string(nil), rendered...) 
+ mu.Unlock() + assert.Equal(t, []string{"one", "one,two"}, gotRendered) +} + +func TestMemoryBackendBundleDropsWholeOpsAtMaxBundleBytes(t *testing.T) { + ctx := context.Background() + clock := newTestClock() + renderer := RendererFunc(func(ops []SessionOp) (string, error) { + names := make([]string, 0, len(ops)) + for _, op := range ops { + names = append(names, op.Name) + } + return strings.Join(names, "|"), nil + }) + backend := NewMemoryBackend(MemoryConfig{ + SessionTTL: time.Hour, + MaxSessions: 100, + MaxBundleBytes: len("one|two"), + Renderer: renderer, + Now: clock.Now, + }) + + _, err := backend.Append(ctx, "session-1", []SessionOp{ + {Name: "one", Body: "query { one }", Kind: OperationKindQuery}, + {Name: "two", Body: "query { two }", Kind: OperationKindQuery}, + {Name: "three", Body: "query { three }", Kind: OperationKindQuery}, + }) + require.NoError(t, err) + + bundle, err := backend.Bundle(ctx, "session-1") + require.NoError(t, err) + assert.Equal(t, "one|two", bundle) +} diff --git a/router/internal/codemode/storage/naming.go b/router/internal/codemode/storage/naming.go new file mode 100644 index 0000000000..a3b91eadc9 --- /dev/null +++ b/router/internal/codemode/storage/naming.go @@ -0,0 +1,191 @@ +package storage + +import ( + "slices" + "strconv" + "strings" + "unicode" +) + +var reservedWords = map[string]struct{}{ + "abstract": {}, + "any": {}, + "as": {}, + "async": {}, + "await": {}, + "boolean": {}, + "break": {}, + "case": {}, + "catch": {}, + "class": {}, + "const": {}, + "constructor": {}, + "continue": {}, + "debugger": {}, + "declare": {}, + "default": {}, + "delete": {}, + "do": {}, + "else": {}, + "enum": {}, + "export": {}, + "extends": {}, + "false": {}, + "finally": {}, + "for": {}, + "from": {}, + "function": {}, + "get": {}, + "if": {}, + "implements": {}, + "import": {}, + "in": {}, + "infer": {}, + "instanceof": {}, + "interface": {}, + "is": {}, + "keyof": {}, + "let": {}, + "module": {}, + "namespace": {}, + "never": {}, + 
"new": {}, + "null": {}, + "number": {}, + "object": {}, + "of": {}, + "package": {}, + "private": {}, + "protected": {}, + "public": {}, + "readonly": {}, + "require": {}, + "return": {}, + "satisfies": {}, + "set": {}, + "static": {}, + "string": {}, + "super": {}, + "switch": {}, + "symbol": {}, + "this": {}, + "throw": {}, + "true": {}, + "try": {}, + "type": {}, + "typeof": {}, + "undefined": {}, + "unique": {}, + "unknown": {}, + "var": {}, + "void": {}, + "while": {}, + "with": {}, + "yield": {}, +} + +func NormalizeName(raw string) string { + // Idempotency: names produced by an earlier NormalizeName call (carrying our reserved-word + // or leading-digit prefixes) round-trip without re-splitting. + if rest, ok := strings.CutPrefix(raw, "op_"); ok { + if _, reserved := reservedWords[rest]; reserved && isLowerCamel(rest) { + return raw + } + } + if rest, ok := strings.CutPrefix(raw, "_"); ok { + if len(rest) > 0 && unicode.IsDigit(rune(rest[0])) && isIdentTail(rest) { + return raw + } + } + words := strings.FieldsFunc(raw, func(r rune) bool { + return !unicode.IsLetter(r) && !unicode.IsDigit(r) + }) + words = slices.DeleteFunc(words, func(word string) bool { + return word == "" + }) + if len(words) == 0 { + return "operation" + } + + var builder strings.Builder + for i, word := range words { + if i == 0 { + builder.WriteString(lowerFirst(word)) + continue + } + builder.WriteString(upperFirst(word)) + } + + name := builder.String() + if name == "" { + name = "operation" + } + if first, _ := firstRune(name); unicode.IsDigit(first) { + name = "_" + name + } + if _, ok := reservedWords[name]; ok { + name = "op_" + name + } + return name +} + +func SuffixedName(base string, taken map[string]struct{}) string { + if _, ok := taken[base]; !ok { + return base + } + for i := 2; ; i++ { + name := base + "_" + strconv.Itoa(i) + if _, ok := taken[name]; !ok { + return name + } + } +} + +func lowerFirst(value string) string { + if value == "" { + return value + } + runes 
:= []rune(value) + runes[0] = unicode.ToLower(runes[0]) + return string(runes) +} + +func upperFirst(value string) string { + if value == "" { + return value + } + runes := []rune(strings.ToLower(value)) + runes[0] = unicode.ToUpper(runes[0]) + return string(runes) +} + +func isLowerCamel(value string) bool { + if value == "" { + return false + } + for i, r := range value { + if i == 0 && !unicode.IsLower(r) { + return false + } + if !unicode.IsLetter(r) && !unicode.IsDigit(r) { + return false + } + } + return true +} + +func isIdentTail(value string) bool { + for _, r := range value { + if !unicode.IsLetter(r) && !unicode.IsDigit(r) { + return false + } + } + return true +} + +func firstRune(value string) (rune, bool) { + for _, r := range value { + return r, true + } + return 0, false +} diff --git a/router/internal/codemode/storage/naming_test.go b/router/internal/codemode/storage/naming_test.go new file mode 100644 index 0000000000..10215f9730 --- /dev/null +++ b/router/internal/codemode/storage/naming_test.go @@ -0,0 +1,84 @@ +package storage + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNormalizeName(t *testing.T) { + tests := []struct { + name string + raw string + want string + }{ + {name: "kebab case", raw: "get-user-by-id", want: "getUserById"}, + {name: "snake case", raw: "get_user_by_id", want: "getUserById"}, + {name: "space separated", raw: "Get User By ID", want: "getUserById"}, + {name: "mixed separators", raw: "get__user--by id", want: "getUserById"}, + {name: "already camel", raw: "getUserById", want: "getUserById"}, + {name: "leading digit", raw: "123foo", want: "_123foo"}, + {name: "leading digit with separators", raw: "123-foo-bar", want: "_123FooBar"}, + {name: "reserved word", raw: "delete", want: "op_delete"}, + {name: "reserved word after normalization", raw: "class", want: "op_class"}, + {name: "invalid punctuation", raw: "get$user#by%id", want: "getUserById"}, + {name: "empty input", raw: "", want: 
"operation"}, + {name: "only invalid input", raw: "$$$", want: "operation"}, + {name: "underscore output for reserved word is not rechecked", raw: "op-delete", want: "opDelete"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, NormalizeName(tt.raw)) + }) + } +} + +func TestSuffixedName(t *testing.T) { + tests := []struct { + name string + base string + taken map[string]struct{} + want string + }{ + { + name: "first use keeps base", + base: "getUser", + taken: map[string]struct{}{}, + want: "getUser", + }, + { + name: "first collision uses suffix two", + base: "getUser", + taken: map[string]struct{}{ + "getUser": {}, + }, + want: "getUser_2", + }, + { + name: "skips occupied suffixes", + base: "getUser", + taken: map[string]struct{}{ + "getUser": {}, + "getUser_2": {}, + "getUser_3": {}, + }, + want: "getUser_4", + }, + { + name: "gap is reused", + base: "getUser", + taken: map[string]struct{}{ + "getUser": {}, + "getUser_3": {}, + }, + want: "getUser_2", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, SuffixedName(tt.base, tt.taken)) + }) + } +} diff --git a/router/internal/codemode/storage/redis_backend.go b/router/internal/codemode/storage/redis_backend.go new file mode 100644 index 0000000000..90e6e66883 --- /dev/null +++ b/router/internal/codemode/storage/redis_backend.go @@ -0,0 +1,362 @@ +package storage + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/redis/go-redis/v9" + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" + "go.uber.org/zap" +) + +const defaultRedisKeyPrefix = "cosmo_code_mode" + +var _ SessionStorage = (*RedisBackend)(nil) + +type RedisConfig struct { + Client redis.UniversalClient + KeyPrefix string + SessionTTL time.Duration + Renderer Renderer + Logger *zap.Logger + Now func() time.Time +} + +type RedisBackend struct { + client redis.UniversalClient + keyPrefix string 
+ sessionTTL time.Duration + renderer Renderer + logger *zap.Logger + now func() time.Time + + schemaMu sync.RWMutex + schema *ast.Document + schemaVer atomic.Uint64 +} + +type redisOpEntry struct { + SessionOp + LastUsed time.Time `json:"last_used"` +} + +type redisBundleEntry struct { + Bundle string `json:"bundle"` + SchemaVer uint64 `json:"schema_ver"` + RenderedAt time.Time `json:"rendered_at"` +} + +func NewRedisBackend(cfg RedisConfig) (*RedisBackend, error) { + if cfg.Client == nil { + return nil, errors.New("code mode redis storage client is not configured") + } + if cfg.KeyPrefix == "" { + cfg.KeyPrefix = defaultRedisKeyPrefix + } + if cfg.SessionTTL <= 0 { + cfg.SessionTTL = defaultSessionTTL + } + if cfg.Logger == nil { + cfg.Logger = zap.NewNop() + } + if cfg.Now == nil { + cfg.Now = time.Now + } + + return &RedisBackend{ + client: cfg.Client, + keyPrefix: cfg.KeyPrefix, + sessionTTL: cfg.SessionTTL, + renderer: cfg.Renderer, + logger: cfg.Logger, + now: cfg.Now, + }, nil +} + +func (b *RedisBackend) Append(ctx context.Context, sessionID string, ops []SessionOp) ([]SessionOp, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + if len(ops) == 0 { + return nil, nil + } + + backoff := 5 * time.Millisecond + var appended []SessionOp + for { + if err := ctx.Err(); err != nil { + return nil, err + } + + opsKey := b.opsKey(sessionID) + bundleKey := b.bundleKey(sessionID) + now := b.now() + err := b.client.Watch(ctx, func(tx *redis.Tx) error { + entries, err := b.readOps(ctx, tx, opsKey) + if err != nil { + return err + } + + taken := make(map[string]struct{}, len(entries)+len(ops)) + for _, entry := range entries { + taken[entry.Name] = struct{}{} + } + appended = make([]SessionOp, 0, len(ops)) + for _, op := range ops { + op.Name = SuffixedName(NormalizeName(op.Name), taken) + taken[op.Name] = struct{}{} + entries = append(entries, redisOpEntry{ + SessionOp: op, + LastUsed: now, + }) + appended = append(appended, op) + } + payload, err := 
json.Marshal(entries) + if err != nil { + return err + } + + _, err = tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error { + pipe.Set(ctx, opsKey, payload, 0) + pipe.Expire(ctx, opsKey, b.sessionTTL) + pipe.Del(ctx, bundleKey) + return nil + }) + return err + }, opsKey) + if err == nil { + return appended, nil + } + + b.logger.Debug("retrying code mode redis append", + zap.String("session_id", sessionID), + zap.Error(err), + ) + if err := sleepWithContext(ctx, backoff); err != nil { + return nil, err + } + backoff *= 2 + if backoff > 100*time.Millisecond { + backoff = 100 * time.Millisecond + } + } +} + +func (b *RedisBackend) GetOp(ctx context.Context, sessionID string, name string) (SessionOp, bool, error) { + if err := ctx.Err(); err != nil { + return SessionOp{}, false, err + } + + opsKey := b.opsKey(sessionID) + entries, err := b.readOps(ctx, b.client, opsKey) + if err != nil { + return SessionOp{}, false, err + } + + for i, entry := range entries { + if entry.Name != name { + continue + } + entries[i].LastUsed = b.now() + b.touchOpBestEffort(ctx, opsKey, name) + return entry.SessionOp, true, nil + } + return SessionOp{}, false, nil +} + +func (b *RedisBackend) ListNames(ctx context.Context, sessionID string) ([]string, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + + entries, err := b.readOps(ctx, b.client, b.opsKey(sessionID)) + if err != nil { + return nil, err + } + + names := make([]string, 0, len(entries)) + for _, entry := range entries { + names = append(names, entry.Name) + } + return names, nil +} + +func (b *RedisBackend) Bundle(ctx context.Context, sessionID string) (string, error) { + if err := ctx.Err(); err != nil { + return "", err + } + + bundleKey := b.bundleKey(sessionID) + cached, err := b.client.Get(ctx, bundleKey).Bytes() + if err == nil { + var entry redisBundleEntry + if err := json.Unmarshal(cached, &entry); err != nil { + return "", fmt.Errorf("decode code mode redis bundle: %w", err) + } + if entry.SchemaVer 
== b.SchemaVersion() { + return entry.Bundle, nil + } + } else if !errors.Is(err, redis.Nil) { + return "", err + } + + opsKey := b.opsKey(sessionID) + entries, err := b.readOps(ctx, b.client, opsKey) + if err != nil { + return "", err + } + if len(entries) == 0 { + if b.renderer == nil { + return "", errors.New("code mode storage renderer is not configured") + } + return b.renderer.Render(ctx, nil, b.Schema()) + } + if b.renderer == nil { + return "", errors.New("code mode storage renderer is not configured") + } + + ops := make([]SessionOp, 0, len(entries)) + for _, entry := range entries { + ops = append(ops, entry.SessionOp) + } + bundle, err := b.renderer.Render(ctx, ops, b.Schema()) + if err != nil { + return "", err + } + + payload, err := json.Marshal(redisBundleEntry{ + Bundle: bundle, + SchemaVer: b.SchemaVersion(), + RenderedAt: b.now(), + }) + if err != nil { + return "", err + } + if err := b.setWithTTL(ctx, bundleKey, payload); err != nil { + b.logger.Warn("failed to cache code mode redis bundle", + zap.String("session_id", sessionID), + zap.Error(err), + ) + } + return bundle, nil +} + +func (b *RedisBackend) Reset(ctx context.Context, sessionID string) error { + if err := ctx.Err(); err != nil { + return err + } + return b.client.Del(ctx, b.opsKey(sessionID), b.bundleKey(sessionID)).Err() +} + +func (b *RedisBackend) SetSchema(schema *ast.Document) { + b.schemaMu.Lock() + b.schema = schema + b.schemaMu.Unlock() + + b.schemaVer.Add(1) +} + +func (b *RedisBackend) Schema() *ast.Document { + b.schemaMu.RLock() + defer b.schemaMu.RUnlock() + return b.schema +} + +func (b *RedisBackend) SchemaVersion() uint64 { + return b.schemaVer.Load() +} + +func (b *RedisBackend) Start(context.Context) error { + return nil +} + +func (b *RedisBackend) Stop() error { + return nil +} + +func (b *RedisBackend) opsKey(sessionID string) string { + return fmt.Sprintf("%s:s:%d:%s:ops", b.keyPrefix, b.SchemaVersion(), sessionID) +} + +func (b *RedisBackend) 
bundleKey(sessionID string) string { + return fmt.Sprintf("%s:s:%d:%s:bundle", b.keyPrefix, b.SchemaVersion(), sessionID) +} + +type redisStringGetter interface { + Get(context.Context, string) *redis.StringCmd +} + +func (b *RedisBackend) readOps(ctx context.Context, getter redisStringGetter, key string) ([]redisOpEntry, error) { + raw, err := getter.Get(ctx, key).Bytes() + if errors.Is(err, redis.Nil) { + return nil, nil + } + if err != nil { + return nil, err + } + + var entries []redisOpEntry + if err := json.Unmarshal(raw, &entries); err != nil { + return nil, fmt.Errorf("decode code mode redis ops: %w", err) + } + return entries, nil +} + +func (b *RedisBackend) touchOpBestEffort(ctx context.Context, key string, name string) { + err := b.client.Watch(ctx, func(tx *redis.Tx) error { + entries, err := b.readOps(ctx, tx, key) + if err != nil { + return err + } + + found := false + for i := range entries { + if entries[i].Name == name { + entries[i].LastUsed = b.now() + found = true + break + } + } + if !found { + return nil + } + + payload, err := json.Marshal(entries) + if err != nil { + return err + } + _, err = tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error { + pipe.Set(ctx, key, payload, 0) + pipe.Expire(ctx, key, b.sessionTTL) + return nil + }) + return err + }, key) + if err != nil && !errors.Is(err, redis.TxFailedErr) { + b.logger.Warn("failed to update code mode redis op last_used", zap.Error(err)) + } +} + +func (b *RedisBackend) setWithTTL(ctx context.Context, key string, value []byte) error { + if err := b.client.Set(ctx, key, value, 0).Err(); err != nil { + return err + } + return b.client.Expire(ctx, key, b.sessionTTL).Err() +} + +func sleepWithContext(ctx context.Context, duration time.Duration) error { + timer := time.NewTimer(duration) + defer timer.Stop() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } +} diff --git a/router/internal/codemode/storage/redis_backend_test.go 
b/router/internal/codemode/storage/redis_backend_test.go new file mode 100644 index 0000000000..3bb736353c --- /dev/null +++ b/router/internal/codemode/storage/redis_backend_test.go @@ -0,0 +1,264 @@ +package storage + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/alicebob/miniredis/v2" + miniredisserver "github.com/alicebob/miniredis/v2/server" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" +) + +type testRedisRenderer func(context.Context, []SessionOp, *ast.Document) (string, error) + +func (f testRedisRenderer) Render(ctx context.Context, ops []SessionOp, schema *ast.Document) (string, error) { + return f(ctx, ops, schema) +} + +func newTestRedisBackend(t *testing.T, renderer Renderer, ttl time.Duration) (*RedisBackend, *miniredis.Miniredis, *redis.Client) { + t.Helper() + + if renderer == nil { + renderer = testRedisRenderer(func(_ context.Context, ops []SessionOp, _ *ast.Document) (string, error) { + names := make([]string, 0, len(ops)) + for _, op := range ops { + names = append(names, op.Name) + } + return strings.Join(names, "\n"), nil + }) + } + if ttl == 0 { + ttl = time.Hour + } + + mr := miniredis.RunT(t) + client := redis.NewClient(&redis.Options{Addr: mr.Addr()}) + t.Cleanup(func() { + require.NoError(t, client.Close()) + }) + + backend, err := NewRedisBackend(RedisConfig{ + Client: client, + KeyPrefix: "test_code_mode", + SessionTTL: ttl, + Renderer: renderer, + Now: func() time.Time { return time.Unix(1_700_000_000, 0).UTC() }, + }) + require.NoError(t, err) + require.NoError(t, backend.Start(context.Background())) + t.Cleanup(func() { + require.NoError(t, backend.Stop()) + }) + + return backend, mr, client +} + +func TestRedisBackendAppendGetOpRoundTrip(t *testing.T) { + ctx := context.Background() + backend, _, _ := newTestRedisBackend(t, nil, 
time.Hour) + + ops := []SessionOp{ + {Name: "get-user", Body: "query GetUser { user { id } }", Kind: OperationKindQuery, Description: "Fetch a user"}, + {Name: "op_delete", Body: "mutation DeleteUser { deleteUser(id: 1) }", Kind: OperationKindMutation, Description: "Delete a user"}, + {Name: "get-user", Body: "query GetUserAgain { user { name } }", Kind: OperationKindQuery, Description: "Fetch user name"}, + } + appended, err := backend.Append(ctx, "session-1", ops) + require.NoError(t, err) + assert.Equal(t, []SessionOp{ + {Name: "getUser", Body: "query GetUser { user { id } }", Kind: OperationKindQuery, Description: "Fetch a user"}, + {Name: "op_delete", Body: "mutation DeleteUser { deleteUser(id: 1) }", Kind: OperationKindMutation, Description: "Delete a user"}, + {Name: "getUser_2", Body: "query GetUserAgain { user { name } }", Kind: OperationKindQuery, Description: "Fetch user name"}, + }, appended) + + gotQuery, ok, err := backend.GetOp(ctx, "session-1", "getUser") + require.NoError(t, err) + assert.Equal(t, true, ok) + assert.Equal(t, SessionOp{Name: "getUser", Body: "query GetUser { user { id } }", Kind: OperationKindQuery, Description: "Fetch a user"}, gotQuery) + + gotCollision, ok, err := backend.GetOp(ctx, "session-1", "getUser_2") + require.NoError(t, err) + assert.Equal(t, true, ok) + assert.Equal(t, SessionOp{Name: "getUser_2", Body: "query GetUserAgain { user { name } }", Kind: OperationKindQuery, Description: "Fetch user name"}, gotCollision) + + gotMissing, ok, err := backend.GetOp(ctx, "session-1", "missing") + require.NoError(t, err) + assert.Equal(t, false, ok) + assert.Equal(t, SessionOp{}, gotMissing) +} + +func TestRedisBackendBundleRendersAndReadsFromCache(t *testing.T) { + ctx := context.Background() + var renders atomic.Int64 + backend, mr, _ := newTestRedisBackend(t, testRedisRenderer(func(_ context.Context, ops []SessionOp, _ *ast.Document) (string, error) { + renders.Add(1) + return fmt.Sprintf("render-%d:%s", renders.Load(), 
ops[0].Name), nil + }), time.Hour) + + _, err := backend.Append(ctx, "session-1", []SessionOp{{Name: "getUser", Body: "query { user { id } }", Kind: OperationKindQuery}}) + require.NoError(t, err) + + first, err := backend.Bundle(ctx, "session-1") + require.NoError(t, err) + assert.Equal(t, "render-1:getUser", first) + assert.Equal(t, true, mr.Exists(backend.bundleKey("session-1"))) + + second, err := backend.Bundle(ctx, "session-1") + require.NoError(t, err) + assert.Equal(t, "render-1:getUser", second) + assert.Equal(t, int64(1), renders.Load()) +} + +func TestRedisBackendResetClearsOpsAndBundleKeys(t *testing.T) { + ctx := context.Background() + backend, mr, _ := newTestRedisBackend(t, nil, time.Hour) + + _, err := backend.Append(ctx, "session-1", []SessionOp{{Name: "getUser", Body: "query { user { id } }", Kind: OperationKindQuery}}) + require.NoError(t, err) + _, err = backend.Bundle(ctx, "session-1") + require.NoError(t, err) + opsKey := backend.opsKey("session-1") + bundleKey := backend.bundleKey("session-1") + assert.Equal(t, true, mr.Exists(opsKey)) + assert.Equal(t, true, mr.Exists(bundleKey)) + + require.NoError(t, backend.Reset(ctx, "session-1")) + + assert.Equal(t, false, mr.Exists(opsKey)) + assert.Equal(t, false, mr.Exists(bundleKey)) +} + +func TestRedisBackendSetSchemaRotatesKeysAndKeepsOldKeysUntilTTL(t *testing.T) { + ctx := context.Background() + schemaA := &ast.Document{RootNodes: []ast.Node{{Kind: ast.NodeKindSchemaDefinition}}} + schemaB := &ast.Document{RootNodes: []ast.Node{{Kind: ast.NodeKindObjectTypeDefinition}}} + backend, mr, _ := newTestRedisBackend(t, testRedisRenderer(func(_ context.Context, _ []SessionOp, schema *ast.Document) (string, error) { + return fmt.Sprintf("schema-kind-%d", schema.RootNodes[0].Kind), nil + }), time.Hour) + backend.SetSchema(schemaA) + + _, err := backend.Append(ctx, "session-1", []SessionOp{{Name: "getUser", Body: "query { user { id } }", Kind: OperationKindQuery}}) + require.NoError(t, err) + oldOpsKey := 
backend.opsKey("session-1") + oldBundleKey := backend.bundleKey("session-1") + first, err := backend.Bundle(ctx, "session-1") + require.NoError(t, err) + assert.Equal(t, fmt.Sprintf("schema-kind-%d", schemaA.RootNodes[0].Kind), first) + assert.Equal(t, true, mr.Exists(oldOpsKey)) + assert.Equal(t, true, mr.Exists(oldBundleKey)) + + oldVersion := backend.SchemaVersion() + backend.SetSchema(schemaB) + + assert.Equal(t, oldVersion+1, backend.SchemaVersion()) + assert.Equal(t, schemaB, backend.Schema()) + assert.Equal(t, true, mr.Exists(oldOpsKey)) + assert.Equal(t, true, mr.Exists(oldBundleKey)) + + _, err = backend.Append(ctx, "session-1", []SessionOp{{Name: "getUser", Body: "query { user { id } }", Kind: OperationKindQuery}}) + require.NoError(t, err) + second, err := backend.Bundle(ctx, "session-1") + require.NoError(t, err) + assert.Equal(t, fmt.Sprintf("schema-kind-%d", schemaB.RootNodes[0].Kind), second) + assert.Equal(t, true, mr.Exists(backend.opsKey("session-1"))) + assert.Equal(t, true, mr.Exists(backend.bundleKey("session-1"))) +} + +func TestRedisBackendConcurrentAppendRetriesWatchConflicts(t *testing.T) { + ctx := context.Background() + backend, mr, _ := newTestRedisBackend(t, nil, time.Hour) + const goroutines = 12 + const opsPerGoroutine = 8 + + var wg sync.WaitGroup + errs := make(chan error, goroutines) + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func(worker int) { + defer wg.Done() + ops := make([]SessionOp, 0, opsPerGoroutine) + for j := 0; j < opsPerGoroutine; j++ { + ops = append(ops, SessionOp{Name: fmt.Sprintf("op_%02d_%02d", worker, j), Body: "query { ok }", Kind: OperationKindQuery}) + } + _, err := backend.Append(ctx, "session-1", ops) + errs <- err + }(i) + } + wg.Wait() + close(errs) + for err := range errs { + require.NoError(t, err) + } + + raw, err := mr.Get(backend.opsKey("session-1")) + require.NoError(t, err) + var entries []redisOpEntry + require.NoError(t, json.Unmarshal([]byte(raw), &entries)) + assert.Equal(t, 
goroutines*opsPerGoroutine, len(entries)) +} + +func TestRedisBackendAppendAbandonsOnContextDone(t *testing.T) { + backend, mr, _ := newTestRedisBackend(t, nil, time.Hour) + mr.SetError("LOADING Redis is loading the dataset in memory") + ctx, cancel := context.WithTimeout(context.Background(), 25*time.Millisecond) + defer cancel() + + _, err := backend.Append(ctx, "session-1", []SessionOp{{Name: "getUser", Body: "query { user { id } }", Kind: OperationKindQuery}}) + + require.Error(t, err) + assert.Equal(t, true, errors.Is(err, context.DeadlineExceeded)) +} + +func TestRedisBackendExpiresKeysOnWrites(t *testing.T) { + ctx := context.Background() + backend, mr, _ := newTestRedisBackend(t, nil, 10*time.Second) + + _, err := backend.Append(ctx, "session-1", []SessionOp{{Name: "getUser", Body: "query { user { id } }", Kind: OperationKindQuery}}) + require.NoError(t, err) + opsKey := backend.opsKey("session-1") + assert.Equal(t, 10*time.Second, mr.TTL(opsKey)) + + _, err = backend.Bundle(ctx, "session-1") + require.NoError(t, err) + bundleKey := backend.bundleKey("session-1") + assert.Equal(t, 10*time.Second, mr.TTL(bundleKey)) + + mr.FastForward(11 * time.Second) + assert.Equal(t, false, mr.Exists(opsKey)) + assert.Equal(t, false, mr.Exists(bundleKey)) +} + +func TestRedisBackendBundleWriteBackIsBestEffort(t *testing.T) { + ctx := context.Background() + backend, mr, _ := newTestRedisBackend(t, testRedisRenderer(func(_ context.Context, ops []SessionOp, _ *ast.Document) (string, error) { + return "rendered:" + ops[0].Name, nil + }), time.Hour) + _, err := backend.Append(ctx, "session-1", []SessionOp{{Name: "getUser", Body: "query { user { id } }", Kind: OperationKindQuery}}) + require.NoError(t, err) + + mr.Server().SetPreHook(func(c *miniredisserver.Peer, cmd string, _ ...string) bool { + if strings.EqualFold(cmd, "set") { + c.WriteError("ERR forced set failure") + return true + } + return false + }) + t.Cleanup(func() { + mr.Server().SetPreHook(nil) + }) + + bundle, 
err := backend.Bundle(ctx, "session-1") + + require.NoError(t, err) + assert.Equal(t, "rendered:getUser", bundle) + assert.Equal(t, false, mr.Exists(backend.bundleKey("session-1"))) +} diff --git a/router/internal/codemode/storage/storage.go b/router/internal/codemode/storage/storage.go new file mode 100644 index 0000000000..fc847a7acb --- /dev/null +++ b/router/internal/codemode/storage/storage.go @@ -0,0 +1,29 @@ +package storage + +import ( + "context" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" +) + +type SessionStorage interface { + Append(ctx context.Context, sessionID string, ops []SessionOp) ([]SessionOp, error) + GetOp(ctx context.Context, sessionID string, name string) (SessionOp, bool, error) + ListNames(ctx context.Context, sessionID string) ([]string, error) + Bundle(ctx context.Context, sessionID string) (string, error) + Reset(ctx context.Context, sessionID string) error + SetSchema(*ast.Document) + Schema() *ast.Document + Start(ctx context.Context) error + Stop() error +} + +type Renderer interface { + Render(ctx context.Context, ops []SessionOp, schema *ast.Document) (string, error) +} + +type RendererFunc func([]SessionOp) (string, error) + +func (f RendererFunc) Render(_ context.Context, ops []SessionOp, _ *ast.Document) (string, error) { + return f(ops) +} diff --git a/router/internal/codemode/storage/types.go b/router/internal/codemode/storage/types.go new file mode 100644 index 0000000000..ba3f1c7df2 --- /dev/null +++ b/router/internal/codemode/storage/types.go @@ -0,0 +1,15 @@ +package storage + +type OperationKind string + +const ( + OperationKindQuery OperationKind = "Query" + OperationKindMutation OperationKind = "Mutation" +) + +type SessionOp struct { + Name string + Body string + Kind OperationKind + Description string +} diff --git a/router/internal/codemode/tsgen/bundle_test.go b/router/internal/codemode/tsgen/bundle_test.go new file mode 100644 index 0000000000..462fd3f0d0 --- /dev/null +++ 
b/router/internal/codemode/tsgen/bundle_test.go @@ -0,0 +1,138 @@ +package tsgen + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router/internal/codemode/storage" +) + +func TestRenderBundleEmptyOps(t *testing.T) { + got, err := RenderBundle(nil, testSchema(t), 0) + require.NoError(t, err) + + want := "type GraphQLError = { message: string; path?: (string | number)[]; extensions?: Record };\n" + + "type R = Promise<{ data: T | null; errors?: GraphQLError[] }>;\n" + + "\n" + + "declare const tools: {};\n" + + "\n" + + "declare function notNull(value: T | null | undefined, message?: string): T;\n" + + "declare function compact(value: T): T;" + + assert.Equal(t, want, got) +} + +func TestRenderBundleThreeOpsNoTruncation(t *testing.T) { + ops := []storage.SessionOp{ + {Name: "health", Body: `query Health { health }`, Kind: storage.OperationKindQuery, Description: "Checks router health."}, + {Name: "viewer", Body: `query Viewer { viewer { id name } }`, Kind: storage.OperationKindQuery, Description: "Fetches viewer."}, + {Name: "renameUser", Body: `mutation RenameUser($id: ID!, $name: String!) { renameUser(id: $id, name: $name) { id } }`, Kind: storage.OperationKindMutation, Description: "Renames a user."}, + } + + got, err := RenderBundle(ops, testSchema(t), 0) + require.NoError(t, err) + + want := "type GraphQLError = { message: string; path?: (string | number)[]; extensions?: Record };\n" + + "type R = Promise<{ data: T | null; errors?: GraphQLError[] }>;\n" + + "\n" + + "declare const tools: {\n" + + " /** Checks router health. */\n" + + " health(): R<{ health: string }>;\n" + + "\n" + + " /** Fetches viewer. */\n" + + " viewer(): R<{ viewer: { id: string; name: string } | null }>;\n" + + "\n" + + " /** Renames a user. 
*/\n" + + " renameUser(vars: { id: string; name: string }): R<{ renameUser: { id: string } }>;\n" + + "};\n" + + "\n" + + "declare function notNull(value: T | null | undefined, message?: string): T;\n" + + "declare function compact(value: T): T;" + + assert.Equal(t, want, got) +} + +func TestRenderBundleTruncatesWholeOpsFromEnd(t *testing.T) { + ops := []storage.SessionOp{ + {Name: "health", Body: `query Health { health }`, Kind: storage.OperationKindQuery, Description: "Checks router health."}, + {Name: "viewer", Body: `query Viewer { viewer { id name } }`, Kind: storage.OperationKindQuery, Description: "Fetches viewer."}, + {Name: "renameUser", Body: `mutation RenameUser($id: ID!, $name: String!) { renameUser(id: $id, name: $name) { id } }`, Kind: storage.OperationKindMutation, Description: "Renames a user."}, + } + fullWithTwo := "type GraphQLError = { message: string; path?: (string | number)[]; extensions?: Record };\n" + + "type R = Promise<{ data: T | null; errors?: GraphQLError[] }>;\n" + + "\n" + + "declare const tools: {\n" + + " /** Checks router health. */\n" + + " health(): R<{ health: string }>;\n" + + "\n" + + " /** Fetches viewer. */\n" + + " viewer(): R<{ viewer: { id: string; name: string } | null }>;\n" + + "};\n" + + "\n" + + "declare function notNull(value: T | null | undefined, message?: string): T;\n" + + "declare function compact(value: T): T;\n" + + "// truncated: 1 ops omitted" + + got, err := RenderBundle(ops, testSchema(t), len(fullWithTwo)) + require.NoError(t, err) + + assert.Equal(t, fullWithTwo, got) +} + +func TestRenderBundleErrorsWhenPreludeCannotFit(t *testing.T) { + _, err := RenderBundle(nil, testSchema(t), 12) + require.Error(t, err) +} + +func TestRenderBundleRoundTripsAbstractField(t *testing.T) { + ops := []storage.SessionOp{ + { + Name: "petsList", + Body: `query PetsList { pets { __typename ... on Cat { name } ... 
on Dog { bark } } }`, + Kind: storage.OperationKindQuery, + Description: "Lists pets.", + }, + } + + got, err := RenderBundle(ops, testSchema(t), 0) + require.NoError(t, err) + + want := "type GraphQLError = { message: string; path?: (string | number)[]; extensions?: Record };\n" + + "type R = Promise<{ data: T | null; errors?: GraphQLError[] }>;\n" + + "\n" + + "declare const tools: {\n" + + " /** Lists pets. */\n" + + " petsList(): R<{ pets: ({ __typename: \"Cat\"; name: string } | { __typename: \"Dog\"; bark: string } | { __typename: \"Mouse\" })[] }>;\n" + + "};\n" + + "\n" + + "declare function notNull(value: T | null | undefined, message?: string): T;\n" + + "declare function compact(value: T): T;" + + assert.Equal(t, want, got) +} + +func TestNewOpsFragmentReturnsOnlySignatures(t *testing.T) { + ops := []storage.SessionOp{ + {Name: "health", Body: `query Health { health }`, Kind: storage.OperationKindQuery, Description: "Checks router health."}, + {Name: "viewer", Body: `query Viewer { viewer { id } }`, Kind: storage.OperationKindQuery, Description: "Fetches viewer."}, + {Name: "animal", Body: `query Animal { animal { id } }`, Kind: storage.OperationKindQuery, Description: "Fetches animal."}, + } + + got, err := NewOpsFragment(ops, testSchema(t)) + require.NoError(t, err) + + want := "/** Checks router health. */\n" + + "health(): R<{ health: string }>;\n" + + "\n" + + "/** Fetches viewer. */\n" + + "viewer(): R<{ viewer: { id: string } | null }>;\n" + + "\n" + + "/** Fetches animal. 
 */\n" +
+		"animal(): R<{ animal: { id: string } | null }>;"

+	assert.Equal(t, want, got)
+	assert.False(t, strings.Contains(got, "declare const tools"))
+	assert.False(t, strings.Contains(got, "type R"))
+}
diff --git a/router/internal/codemode/tsgen/graphql.go b/router/internal/codemode/tsgen/graphql.go
new file mode 100644
index 0000000000..352f0d21a8
--- /dev/null
+++ b/router/internal/codemode/tsgen/graphql.go
@@ -0,0 +1,674 @@
+package tsgen
+
+import (
+	"fmt"
+	"strconv"
+	"strings"
+
+	"github.com/wundergraph/cosmo/router/internal/codemode/storage"
+	"github.com/wundergraph/graphql-go-tools/v2/pkg/ast"
+	"github.com/wundergraph/graphql-go-tools/v2/pkg/astparser"
+)
+
+// operationRenderer lowers a parsed GraphQL operation document into a
+// TypeScript method signature, resolving all field and input types against
+// the merged schema document.
+type operationRenderer struct {
+	schema *ast.Document
+}
+
+// renderOperation renders one session operation as a TypeScript signature.
+// It parses op.Body, requires exactly one operation definition, derives the
+// variables object type (and whether the vars argument may be omitted) plus
+// the response data type, and hands formatting to writeFieldSignature.
+// Every error is wrapped with the op name for traceability.
+func (r operationRenderer) renderOperation(op storage.SessionOp) (string, error) {
+	if r.schema == nil {
+		return "", fmt.Errorf("render op %q: schema is nil", op.Name)
+	}
+
+	opDoc, report := astparser.ParseGraphqlDocumentString(op.Body)
+	if report.HasErrors() {
+		return "", fmt.Errorf("render op %q: parse GraphQL operation: %s", op.Name, report.Error())
+	}
+
+	opRef, err := singleOperationRef(&opDoc)
+	if err != nil {
+		return "", fmt.Errorf("render op %q: %w", op.Name, err)
+	}
+
+	varsType, varsOptional, err := r.variablesType(&opDoc, opRef)
+	if err != nil {
+		return "", fmt.Errorf("render op %q: %w", op.Name, err)
+	}
+
+	outputType, err := r.outputType(&opDoc, opRef)
+	if err != nil {
+		return "", fmt.Errorf("render op %q: %w", op.Name, err)
+	}
+
+	return writeFieldSignature(op.Description, op.Name, varsType, outputType, varsOptional), nil
+}
+
+// singleOperationRef returns the ref of the document's only operation
+// definition; it errors when the document holds zero or more than one.
+func singleOperationRef(doc *ast.Document) (int, error) {
+	var refs []int
+	for _, node := range doc.RootNodes {
+		if node.Kind == ast.NodeKindOperationDefinition {
+			refs = append(refs, node.Ref)
+		}
+	}
+	if len(refs) == 0 {
+		return 0, fmt.Errorf("operation document contains no operation definition")
+	}
+	if len(refs) > 1 {
+		return 0, fmt.Errorf("operation document contains %d operation definitions", len(refs))
+	}
+	return refs[0], nil
+}
+
+// variablesType builds the TypeScript object type for the operation's
+// variable definitions. The second result reports whether the whole vars
+// argument may be omitted by the caller: true only when every variable is
+// nullable (i.e. no NonNull variable exists).
+func (r operationRenderer) variablesType(opDoc *ast.Document, opRef int) (string, bool, error) {
+	op := opDoc.OperationDefinitions[opRef]
+	if !op.HasVariableDefinitions || len(op.VariableDefinitions.Refs) == 0 {
+		return "{}", true, nil
+	}
+
+	fields := make([]tsProperty, 0, len(op.VariableDefinitions.Refs))
+	varsOptional := true
+	for _, varRef := range op.VariableDefinitions.Refs {
+		name := opDoc.VariableDefinitionNameString(varRef)
+		typeRef := opDoc.VariableDefinitionType(varRef)
+		// A NonNull wrapper makes the variable a required property; the
+		// inner nullability additionally controls the `| null` suffix.
+		required := opDoc.Types[typeRef].TypeKind == ast.TypeKindNonNull
+
+		typ, nullable, err := r.inputType(opDoc, typeRef)
+		if err != nil {
+			return "", false, err
+		}
+		if nullable {
+			typ = writeNullable(typ)
+		} else {
+			varsOptional = false
+		}
+
+		fields = append(fields, tsProperty{name: name, typ: typ, optional: !required})
+	}
+
+	return writeInlineObject(fields), varsOptional, nil
+}
+
+// inputType lowers a GraphQL input type reference (from doc — the operation
+// document for variables, the schema for input-object fields) into a TS type
+// string plus a nullability flag. NonNull and List wrappers recurse into
+// OfType; named types resolve through inputNamedType.
+func (r operationRenderer) inputType(doc *ast.Document, typeRef int) (string, bool, error) {
+	gqlType := doc.Types[typeRef]
+	switch gqlType.TypeKind {
+	case ast.TypeKindNonNull:
+		typ, _, err := r.inputType(doc, gqlType.OfType)
+		return typ, false, err
+	case ast.TypeKindList:
+		item, itemNullable, err := r.inputType(doc, gqlType.OfType)
+		if err != nil {
+			return "", false, err
+		}
+		if itemNullable {
+			item = writeNullable(item)
+		}
+		return writeArray(item), true, nil
+	case ast.TypeKindNamed:
+		typ, err := r.inputNamedType(doc.TypeNameString(typeRef))
+		return typ, true, err
+	default:
+		return "", false, fmt.Errorf("unsupported GraphQL input type kind %s", gqlType.TypeKind.String())
+	}
+}
+
+// inputNamedType maps a named input type: built-in scalars map directly,
+// enums become string-literal unions, input objects become inline object
+// types, and custom scalars (or anything unrecognized) fall back to `unknown`.
+func (r operationRenderer) inputNamedType(typeName string) (string, error) {
+	switch typeName {
+	case "ID", "String":
+		return "string", nil
+	case "Int", "Float":
+		return "number", nil
+	case "Boolean":
+		return "boolean", nil
+	}
+
+	node, exists :=
r.schema.Index.FirstNonExtensionNodeByNameBytes([]byte(typeName))
+	if !exists {
+		return "", fmt.Errorf("missing schema type %q", typeName)
+	}
+
+	switch node.Kind {
+	case ast.NodeKindEnumTypeDefinition:
+		values := r.enumValues(node.Ref)
+		return writeStringLiteralUnion(values), nil
+	case ast.NodeKindInputObjectTypeDefinition:
+		return r.inputObjectType(node.Ref)
+	case ast.NodeKindScalarTypeDefinition:
+		return "unknown", nil
+	default:
+		return "unknown", nil
+	}
+}
+
+// enumValues returns the enum's value names in schema declaration order.
+func (r operationRenderer) enumValues(enumRef int) []string {
+	def := r.schema.EnumTypeDefinitions[enumRef]
+	values := make([]string, 0, len(def.EnumValuesDefinition.Refs))
+	for _, valueRef := range def.EnumValuesDefinition.Refs {
+		values = append(values, r.schema.EnumValueDefinitionNameString(valueRef))
+	}
+	return values
+}
+
+// inputObjectType renders a schema input-object definition as an inline TS
+// object type: fields with a NonNull type are required, nullable fields get
+// `| null` appended. Field types are resolved against the schema document.
+func (r operationRenderer) inputObjectType(inputObjectRef int) (string, error) {
+	def := r.schema.InputObjectTypeDefinitions[inputObjectRef]
+	fields := make([]tsProperty, 0, len(def.InputFieldsDefinition.Refs))
+	for _, fieldRef := range def.InputFieldsDefinition.Refs {
+		name := r.schema.InputValueDefinitionNameString(fieldRef)
+		typeRef := r.schema.InputValueDefinitionType(fieldRef)
+		required := r.schema.Types[typeRef].TypeKind == ast.TypeKindNonNull
+
+		// Input-object fields live in the schema document, so the schema is
+		// passed as the type source here (unlike variables, which live in
+		// the operation document).
+		typ, nullable, err := r.inputType(r.schema, typeRef)
+		if err != nil {
+			return "", err
+		}
+		if nullable {
+			typ = writeNullable(typ)
+		}
+
+		fields = append(fields, tsProperty{name: name, typ: typ, optional: !required})
+	}
+
+	return writeInlineObject(fields), nil
+}
+
+// outputType renders the response data type for the operation's root
+// selection set, rooted at the schema's query/mutation/subscription type.
+func (r operationRenderer) outputType(opDoc *ast.Document, opRef int) (string, error) {
+	op := opDoc.OperationDefinitions[opRef]
+	rootNode, err := r.rootOperationNode(op.OperationType)
+	if err != nil {
+		return "", err
+	}
+
+	return r.selectionSetType(opDoc, op.SelectionSet, rootNode)
+}
+
+// rootOperationNode resolves the schema node for the operation's root type,
+// falling back to the conventional names Query/Mutation/Subscription when
+// the schema index does not record an explicit root type name.
+func (r operationRenderer) rootOperationNode(operationType ast.OperationType) (ast.Node, error) {
+	var typeName []byte
+	switch operationType {
+	case ast.OperationTypeQuery:
+		typeName = r.schema.Index.QueryTypeName
+		if len(typeName) == 0 {
+			typeName = []byte("Query")
+		}
+	case ast.OperationTypeMutation:
+		typeName = r.schema.Index.MutationTypeName
+		if len(typeName) == 0 {
+			typeName = []byte("Mutation")
+		}
+	case ast.OperationTypeSubscription:
+		typeName = r.schema.Index.SubscriptionTypeName
+		if len(typeName) == 0 {
+			typeName = []byte("Subscription")
+		}
+	default:
+		return ast.Node{}, fmt.Errorf("unsupported operation type %s", operationType.Name())
+	}
+
+	node, exists := r.schema.Index.FirstNonExtensionNodeByNameBytes(typeName)
+	if !exists {
+		return ast.Node{}, fmt.Errorf("missing schema root type %q", string(typeName))
+	}
+	return node, nil
+}
+
+// selectionSetType renders a selection set on the given parent schema node
+// as an inline TS object, flattening inline fragments and fragment spreads
+// into the same property list as plain fields.
+func (r operationRenderer) selectionSetType(opDoc *ast.Document, selectionSetRef int, parent ast.Node) (string, error) {
+	selections := opDoc.SelectionSets[selectionSetRef]
+	fields := make([]tsProperty, 0, len(selections.SelectionRefs))
+
+	for _, selectionRef := range selections.SelectionRefs {
+		selection := opDoc.Selections[selectionRef]
+		switch selection.Kind {
+		case ast.SelectionKindField:
+			field, err := r.fieldProperty(opDoc, selection.Ref, parent)
+			if err != nil {
+				return "", err
+			}
+			fields = append(fields, field)
+		case ast.SelectionKindInlineFragment:
+			inlineFields, err := r.inlineFragmentProperties(opDoc, selection.Ref, parent)
+			if err != nil {
+				return "", err
+			}
+			fields = append(fields, inlineFields...)
+		case ast.SelectionKindFragmentSpread:
+			fragmentFields, err := r.fragmentSpreadProperties(opDoc, selection.Ref, parent)
+			if err != nil {
+				return "", err
+			}
+			fields = append(fields, fragmentFields...)
+		default:
+			return "", fmt.Errorf("unsupported selection kind %s", selection.Kind.String())
+		}
+	}
+
+	return writeInlineObject(fields), nil
+}
+
+// fieldProperty renders one selected field as a TS property, using the alias
+// as the property name when present. A bare __typename is rendered as plain
+// `string` here; literal __typename types are produced by abstractFieldType.
+func (r operationRenderer) fieldProperty(opDoc *ast.Document, fieldRef int, parent ast.Node) (tsProperty, error) {
+	name := opDoc.FieldNameString(fieldRef)
+	propName := opDoc.FieldAliasOrNameString(fieldRef)
+
+	if name == "__typename" {
+		return tsProperty{name: propName, typ: "string"}, nil
+	}
+
+	fieldDefRef, exists := r.schema.NodeFieldDefinitionByName(parent, []byte(name))
+	if !exists {
+		return tsProperty{}, fmt.Errorf("missing field %q on schema type %q", name, parent.NameString(r.schema))
+	}
+
+	// -1 marks "no sub-selection"; outputNamedType rejects that for object
+	// and abstract types, which always require one.
+	selectionSetRef := -1
+	if opDoc.Fields[fieldRef].HasSelections {
+		selectionSetRef = opDoc.Fields[fieldRef].SelectionSet
+	}
+
+	typeRef := r.schema.FieldDefinitionType(fieldDefRef)
+	typ, nullable, err := r.outputGraphQLType(opDoc, typeRef, selectionSetRef)
+	if err != nil {
+		return tsProperty{}, err
+	}
+	if nullable {
+		typ = writeNullable(typ)
+	}
+
+	return tsProperty{name: propName, typ: typ}, nil
+}
+
+// outputGraphQLType lowers a schema output type reference to a TS type
+// string plus a nullability flag; NonNull and List wrappers recurse into
+// OfType and named types resolve through outputNamedType.
+func (r operationRenderer) outputGraphQLType(opDoc *ast.Document, typeRef int, selectionSetRef int) (string, bool, error) {
+	gqlType := r.schema.Types[typeRef]
+	switch gqlType.TypeKind {
+	case ast.TypeKindNonNull:
+		typ, _, err := r.outputGraphQLType(opDoc, gqlType.OfType, selectionSetRef)
+		return typ, false, err
+	case ast.TypeKindList:
+		item, itemNullable, err := r.outputGraphQLType(opDoc, gqlType.OfType, selectionSetRef)
+		if err != nil {
+			return "", false, err
+		}
+		if itemNullable {
+			item = writeNullable(item)
+		}
+		return writeArray(item), true, nil
+	case ast.TypeKindNamed:
+		typ, err := r.outputNamedType(opDoc, r.schema.TypeNameString(typeRef), selectionSetRef)
+		return typ, true, err
+	default:
+		return "", false, fmt.Errorf("unsupported GraphQL output type kind %s", gqlType.TypeKind.String())
+	}
+}
+
+// outputNamedType maps a named output type: built-in scalars map directly,
+// enums become string-literal unions, object types render their selection
+// set, interface/union types go through abstract lowering, and custom
+// scalars fall back to `unknown`. selectionSetRef < 0 means the field had no
+// sub-selection, which is an error for object and abstract types.
+func (r operationRenderer) outputNamedType(opDoc *ast.Document, typeName string, selectionSetRef int) (string, error) {
+	switch typeName {
+	case "ID", "String":
+		return "string", nil
+	case "Int", "Float":
+		return "number", nil
+	case "Boolean":
+		return "boolean", nil
+	}
+
+	node, exists := r.schema.Index.FirstNonExtensionNodeByNameBytes([]byte(typeName))
+	if !exists {
+		return "", fmt.Errorf("missing schema type %q", typeName)
+	}
+
+	switch node.Kind {
+	case ast.NodeKindEnumTypeDefinition:
+		return writeStringLiteralUnion(r.enumValues(node.Ref)), nil
+	case ast.NodeKindObjectTypeDefinition:
+		if selectionSetRef < 0 {
+			return "", fmt.Errorf("object type %q requires a selection set", typeName)
+		}
+		return r.selectionSetType(opDoc, selectionSetRef, node)
+	case ast.NodeKindInterfaceTypeDefinition, ast.NodeKindUnionTypeDefinition:
+		if selectionSetRef < 0 {
+			return "", fmt.Errorf("abstract type %q requires a selection set", typeName)
+		}
+		return r.abstractFieldType(opDoc, selectionSetRef, node)
+	case ast.NodeKindScalarTypeDefinition:
+		return "unknown", nil
+	default:
+		return "unknown", nil
+	}
+}
+
+// inlineFragmentProperties renders an inline fragment's selection set
+// against its type condition (or the parent node when the fragment has no
+// condition) and flattens the result back into individual properties.
+func (r operationRenderer) inlineFragmentProperties(opDoc *ast.Document, inlineRef int, parent ast.Node) ([]tsProperty, error) {
+	fragment := opDoc.InlineFragments[inlineRef]
+	fragmentParent := parent
+	if opDoc.InlineFragmentHasTypeCondition(inlineRef) {
+		typeName := opDoc.InlineFragmentTypeConditionNameString(inlineRef)
+		node, exists := r.schema.Index.FirstNonExtensionNodeByNameBytes([]byte(typeName))
+		if !exists {
+			return nil, fmt.Errorf("missing schema type %q", typeName)
+		}
+		fragmentParent = node
+	}
+
+	typ, err := r.selectionSetType(opDoc, fragment.SelectionSet, fragmentParent)
+	if err != nil {
+		return nil, err
+	}
+
+	return propertiesFromInlineObject(typ), nil
+}
+
+// fragmentSpreadProperties resolves a named fragment spread, renders its
+// selection set against the fragment's type condition, and flattens the
+// result into properties for the enclosing object.
+func (r operationRenderer) fragmentSpreadProperties(opDoc *ast.Document, spreadRef int, parent ast.Node) ([]tsProperty, error) {
+	fragmentName := opDoc.FragmentSpreadNameBytes(spreadRef)
+	fragmentRef, exists := opDoc.FragmentDefinitionRef(fragmentName)
 if !exists {
+		return nil, fmt.Errorf("missing fragment %q", string(fragmentName))
+	}
+
+	fragment := opDoc.FragmentDefinitions[fragmentRef]
+	fragmentParent := parent
+	typeName := opDoc.ResolveTypeNameString(fragment.TypeCondition.Type)
+	if typeName != "" {
+		node, nodeExists := r.schema.Index.FirstNonExtensionNodeByNameBytes([]byte(typeName))
+		if !nodeExists {
+			return nil, fmt.Errorf("missing schema type %q", typeName)
+		}
+		fragmentParent = node
+	}
+
+	typ, err := r.selectionSetType(opDoc, fragment.SelectionSet, fragmentParent)
+	if err != nil {
+		return nil, err
+	}
+
+	return propertiesFromInlineObject(typ), nil
+}
+
+// propertiesFromInlineObject re-parses a rendered inline object type string
+// ("{ a: t; b: t }") back into its properties so fragment selections can be
+// merged into the enclosing object's property list.
+// NOTE(review): this parser depends byte-for-byte on the exact formatting
+// produced by writeInlineObject ("{ " prefix, " }" suffix, "; " separators,
+// ": " after each name) — confirm the two stay in sync if either changes.
+func propertiesFromInlineObject(typ string) []tsProperty {
+	if typ == "{}" {
+		return nil
+	}
+
+	// Strip the "{ " / " }" delimiters before splitting.
+	inner := typ[2 : len(typ)-2]
+	parts := splitInlineObjectFields(inner)
+	props := make([]tsProperty, 0, len(parts))
+	for _, part := range parts {
+		nameAndType := splitProperty(part)
+		if nameAndType.name == "" {
+			continue
+		}
+		props = append(props, nameAndType)
+	}
+
+	return props
+}
+
+// splitInlineObjectFields splits the interior of a rendered inline object on
+// top-level "; " separators, tracking brace depth so that separators inside
+// nested object types are left alone.
+func splitInlineObjectFields(inner string) []string {
+	var parts []string
+	start := 0
+	depth := 0
+	for i := 0; i < len(inner); i++ {
+		switch inner[i] {
+		case '{':
+			depth++
+		case '}':
+			depth--
+		case ';':
+			// Only a depth-0 "; " pair is a field separator.
+			if depth == 0 && i+1 < len(inner) && inner[i+1] == ' ' {
+				parts = append(parts, inner[start:i])
+				start = i + 2
+			}
+		}
+	}
+	parts = append(parts, inner[start:])
+	return parts
+}
+
+// splitProperty splits a rendered "name: type" (or "name?: type") pair at
+// its first colon; it returns the zero value when no colon is found.
+// NOTE(review): part[i+2:] assumes the colon is always followed by a space
+// and a type (true for writeInlineObject output); a trailing bare colon
+// would slice out of range — confirm inputs only ever come from the renderer.
+func splitProperty(part string) tsProperty {
+	for i := 0; i < len(part); i++ {
+		if part[i] != ':' {
+			continue
+		}
+		optional := i > 0 && part[i-1] == '?'
+		nameEnd := i
+		if optional {
+			nameEnd--
+		}
+		return tsProperty{name: part[:nameEnd], typ: part[i+2:], optional: optional}
+	}
+	return tsProperty{}
+}
+
+// abstractSelectionSet describes a fragment to be applied to the matching
+// branches when lowering an abstract-typed field.
`condition` is the schema
+// node referenced by the fragment's type condition (or the parent abstract
+// node itself for inline fragments without a type condition).
+type abstractSelectionSet struct {
+	condition       ast.Node
+	selectionSetRef int
+}
+
+// abstractFieldType lowers a selection set on an interface- or union-typed
+// field into a flat discriminated union of branches, one per concrete
+// implementor.
+func (r operationRenderer) abstractFieldType(opDoc *ast.Document, selectionSetRef int, parent ast.Node) (string, error) {
+	parentName := parent.NameString(r.schema)
+	possibleNames := r.possibleTypeNames(parent)
+	if len(possibleNames) == 0 {
+		return "", fmt.Errorf("abstract type %q has no possible types", parentName)
+	}
+	// Set form of possibleNames for the fragment-condition checks below.
+	possibleSet := make(map[string]struct{}, len(possibleNames))
+	for _, name := range possibleNames {
+		possibleSet[name] = struct{}{}
+	}
+
+	selections := opDoc.SelectionSets[selectionSetRef]
+	if len(selections.SelectionRefs) == 0 {
+		return "", fmt.Errorf("abstract type %q requires at least one selection", parentName)
+	}
+
+	// Bucket the selections.
+	var bareFieldRefs []int   // Field selections defined on the abstract parent itself
+	var typenameSelected bool // unaliased __typename selected directly
+	var fragments []abstractSelectionSet
+
+	for _, selRef := range selections.SelectionRefs {
+		sel := opDoc.Selections[selRef]
+		switch sel.Kind {
+		case ast.SelectionKindField:
+			fieldRef := sel.Ref
+			fieldName := opDoc.FieldNameString(fieldRef)
+			if fieldName == "__typename" {
+				if opDoc.FieldAliasOrNameString(fieldRef) == "__typename" {
+					typenameSelected = true
+				} else {
+					// aliased __typename: render through normal field path on each branch
+					bareFieldRefs = append(bareFieldRefs, fieldRef)
+				}
+				continue
+			}
+			// Non-typename bare field is only valid on interface parents and must
+			// be defined on the parent interface.
+			if parent.Kind != ast.NodeKindInterfaceTypeDefinition {
+				return "", fmt.Errorf("field %q is not valid on union type %q", fieldName, parentName)
+			}
+			if _, exists := r.schema.NodeFieldDefinitionByName(parent, []byte(fieldName)); !exists {
+				return "", fmt.Errorf("missing field %q on interface %q", fieldName, parentName)
+			}
+			bareFieldRefs = append(bareFieldRefs, fieldRef)
+		case ast.SelectionKindInlineFragment:
+			inlineRef := sel.Ref
+			inline := opDoc.InlineFragments[inlineRef]
+			// Without a type condition the fragment targets the parent itself.
+			condition := parent
+			if opDoc.InlineFragmentHasTypeCondition(inlineRef) {
+				typeName := opDoc.InlineFragmentTypeConditionNameString(inlineRef)
+				node, exists := r.schema.Index.FirstNonExtensionNodeByNameBytes([]byte(typeName))
+				if !exists {
+					return "", fmt.Errorf("missing schema type %q", typeName)
+				}
+				condition = node
+			}
+			if err := r.checkAbstractFragmentCondition(condition, possibleSet, parentName); err != nil {
+				return "", err
+			}
+			fragments = append(fragments, abstractSelectionSet{
+				condition:       condition,
+				selectionSetRef: inline.SelectionSet,
+			})
+		case ast.SelectionKindFragmentSpread:
+			spreadRef := sel.Ref
+			fragmentName := opDoc.FragmentSpreadNameBytes(spreadRef)
+			fragRef, exists := opDoc.FragmentDefinitionRef(fragmentName)
+			if !exists {
+				return "", fmt.Errorf("missing fragment %q", string(fragmentName))
+			}
+			fragment := opDoc.FragmentDefinitions[fragRef]
+			typeName := opDoc.ResolveTypeNameString(fragment.TypeCondition.Type)
+			if typeName == "" {
+				return "", fmt.Errorf("fragment %q has no type condition", string(fragmentName))
+			}
+			node, nodeExists := r.schema.Index.FirstNonExtensionNodeByNameBytes([]byte(typeName))
+			if !nodeExists {
+				return "", fmt.Errorf("missing schema type %q", typeName)
+			}
+			if err := r.checkAbstractFragmentCondition(node, possibleSet, parentName); err != nil {
+				return "", err
+			}
+			fragments = append(fragments, abstractSelectionSet{
+				condition:       node,
+				selectionSetRef: fragment.SelectionSet,
+			})
+		default:
+			return "", fmt.Errorf("unsupported selection kind %s", sel.Kind.String())
+		}
+	}
+
+	// Build a branch per concrete implementor.
+	// Possible types that are not object definitions are skipped silently.
+	branches := make([]string, 0, len(possibleNames))
+	for _, typeName := range possibleNames {
+		concreteNode, exists := r.schema.Index.FirstNonExtensionNodeByNameBytes([]byte(typeName))
+		if !exists || concreteNode.Kind != ast.NodeKindObjectTypeDefinition {
+			continue
+		}
+
+		fields := make([]tsProperty, 0)
+
+		// Bare fields rendered against the concrete type. (For unions there
+		// will only be aliased __typename here, since other bare fields are
+		// rejected above.)
+		for _, fieldRef := range bareFieldRefs {
+			prop, err := r.fieldProperty(opDoc, fieldRef, concreteNode)
+			if err != nil {
+				return "", err
+			}
+			fields = append(fields, prop)
+		}
+
+		// Fragments whose target includes this concrete type.
+		for _, frag := range fragments {
+			if !abstractFragmentApplies(frag.condition, typeName, possibleSet, r.schema) {
+				continue
+			}
+			fragTyp, err := r.selectionSetType(opDoc, frag.selectionSetRef, concreteNode)
+			if err != nil {
+				return "", err
+			}
+			fields = append(fields, propertiesFromInlineObject(fragTyp)...)
+		}
+
+		// __typename literal: prepend if explicitly selected.
+		if typenameSelected {
+			literal := tsProperty{name: "__typename", typ: strconv.Quote(typeName)}
+			fields = append([]tsProperty{literal}, fields...)
+		}
+
+		// Drop empty branches.
+		if len(fields) == 0 {
+			continue
+		}
+
+		branches = append(branches, writeInlineObject(fields))
+	}
+
+	if len(branches) == 0 {
+		// Every implementor has zero observable fields. Fall back to a single
+		// empty object so the type checker still sees a valid shape.
+		return "{}", nil
+	}
+
+	if len(branches) == 1 {
+		return branches[0], nil
+	}
+
+	// Single-shape collapse: every branch identical → one shape.
+	allEqual := true
+	for i := 1; i < len(branches); i++ {
+		if branches[i] != branches[0] {
+			allEqual = false
+			break
+		}
+	}
+	if allEqual {
+		return branches[0], nil
+	}
+
+	return strings.Join(branches, " | "), nil
+}
+
+// possibleTypeNames returns the concrete object type names that satisfy the
+// given abstract parent, in schema declaration order.
+func (r operationRenderer) possibleTypeNames(parent ast.Node) []string {
+	switch parent.Kind {
+	case ast.NodeKindInterfaceTypeDefinition:
+		names, _ := r.schema.InterfaceTypeDefinitionImplementedByObjectWithNames(parent.Ref)
+		return names
+	case ast.NodeKindUnionTypeDefinition:
+		names, _ := r.schema.UnionTypeDefinitionMemberTypeNames(parent.Ref)
+		return names
+	case ast.NodeKindObjectTypeDefinition:
+		return []string{r.schema.ObjectTypeDefinitionNameString(parent.Ref)}
+	}
+	return nil
+}
+
+// abstractFragmentApplies decides whether a fragment with the given condition
+// applies to the concrete branch named typeName under the parent abstract
+// (whose possible types are in parentSet).
+func abstractFragmentApplies(condition ast.Node, typeName string, parentSet map[string]struct{}, schema *ast.Document) bool {
+	switch condition.Kind {
+	case ast.NodeKindObjectTypeDefinition:
+		return schema.ObjectTypeDefinitionNameString(condition.Ref) == typeName
+	case ast.NodeKindInterfaceTypeDefinition:
+		// applies to any T that implements this interface AND is in parentSet.
+		impls, _ := schema.InterfaceTypeDefinitionImplementedByObjectWithNames(condition.Ref)
+		for _, name := range impls {
+			if name == typeName {
+				if _, ok := parentSet[name]; ok {
+					return true
+				}
+			}
+		}
+		return false
+	case ast.NodeKindUnionTypeDefinition:
+		members, _ := schema.UnionTypeDefinitionMemberTypeNames(condition.Ref)
+		for _, name := range members {
+			if name == typeName {
+				if _, ok := parentSet[name]; ok {
+					return true
+				}
+			}
+		}
+		return false
+	}
+	return false
+}
+
+// checkAbstractFragmentCondition rejects fragments whose type condition can
+// never apply under the given parent abstract.
+func (r operationRenderer) checkAbstractFragmentCondition(condition ast.Node, parentSet map[string]struct{}, parentName string) error {
+	switch condition.Kind {
+	case ast.NodeKindObjectTypeDefinition:
+		name := r.schema.ObjectTypeDefinitionNameString(condition.Ref)
+		if _, ok := parentSet[name]; !ok {
+			return fmt.Errorf("type %q is not a possible type of %q", name, parentName)
+		}
+	case ast.NodeKindInterfaceTypeDefinition, ast.NodeKindUnionTypeDefinition:
+		// abstract conditions are always allowed; their target is the
+		// intersection with the parent's possible types (which may be empty
+		// — that just means the fragment contributes nothing.
+	default:
+		return fmt.Errorf("unsupported fragment type condition %s", condition.Kind.String())
+	}
+	return nil
+}
diff --git a/router/internal/codemode/tsgen/tsgen.go b/router/internal/codemode/tsgen/tsgen.go
new file mode 100644
index 0000000000..ad18f5ab71
--- /dev/null
+++ b/router/internal/codemode/tsgen/tsgen.go
@@ -0,0 +1,117 @@
+package tsgen
+
+import (
+	"context"
+	"fmt"
+	"strings"
+
+	"github.com/wundergraph/cosmo/router/internal/codemode/storage"
+	"github.com/wundergraph/graphql-go-tools/v2/pkg/ast"
+)
+
+const (
+	defaultMaxBundleBytes = 64 * 1024
+	// Static prelude and helper declarations emitted around the tools object.
+	// NOTE(review): generic parameters (e.g. Record<string, unknown>, R<T>,
+	// notNull<T>, compact<T>) appear to have been stripped from these
+	// literals in transit — confirm against the original source.
+	graphQLErrorAlias = "type GraphQLError = { message: string; path?: (string | number)[]; extensions?: Record };"
+	responseAlias     = "type R = Promise<{ data: T | null; errors?: GraphQLError[] }>;"
+	notNullHelper     = "declare function notNull(value: T | null | undefined, message?: string): T;"
+	compactHelper     = "declare function compact(value: T): T;"
+)
+
+// Renderer renders a session's operation catalog as one TypeScript
+// declaration bundle, truncated to at most MaxBytes bytes.
+type Renderer struct {
+	Schema   *ast.Document
+	MaxBytes int
+}
+
+// Adapter wraps a schema (and an optional byte limit) as a storage.Renderer.
+// When maxBytes is omitted, defaultMaxBundleBytes applies.
+func Adapter(schema *ast.Document, maxBytes ...int) storage.Renderer {
+	limit := defaultMaxBundleBytes
+	if len(maxBytes) > 0 {
+		limit = maxBytes[0]
+	}
+
+	return Renderer{Schema: schema, MaxBytes: limit}
+}
+
+// Render implements storage.Renderer. A nil schema argument falls back to
+// the renderer's configured schema.
+func (r Renderer) Render(_ context.Context, ops []storage.SessionOp, schema *ast.Document) (string, error) {
+	if schema == nil {
+		schema = r.Schema
+	}
+	return RenderBundle(ops, schema, r.MaxBytes)
+}
+
+// NewOpsFragment renders only the signatures for ops — no prelude and no
+// `declare const tools` wrapper — joined by blank lines.
+func NewOpsFragment(ops []storage.SessionOp, schema *ast.Document) (string, error) {
+	renderer := operationRenderer{schema: schema}
+
+	blocks := make([]string, 0, len(ops))
+	for _, op := range ops {
+		block, err := renderer.renderOperation(op)
+		if err != nil {
+			return "", err
+		}
+		blocks = append(blocks, block)
+	}
+
+	return strings.Join(blocks, "\n\n"), nil
+}
+
+// RenderBundle renders the full declaration bundle for ops. When maxBytes is
+// positive and the bundle exceeds it, whole operations are dropped from the
+// end (with a trailing "// truncated" note) until it fits; it errors when
+// even the empty prelude cannot fit within maxBytes.
+func RenderBundle(ops []storage.SessionOp, schema *ast.Document, maxBytes int) (string, error) {
+	renderer := operationRenderer{schema: schema}
+
+	blocks := make([]string, 0, len(ops))
+	for _, op := range ops {
+		block, err := renderer.renderOperation(op)
+		if err != nil {
+			return "", err
+		}
+		blocks = append(blocks, block)
+	}
+
+	if maxBytes <= 0 {
+		return renderBundleBlocks(blocks, 0), nil
+	}
+
+	full := renderBundleBlocks(blocks, 0)
+	// len on a string already counts bytes; the previous len([]byte(full))
+	// copied the whole bundle just to measure it.
+	if len(full) <= maxBytes {
+		return full, nil
+	}
+
+	for omitted := 1; omitted <= len(blocks); omitted++ {
+		candidate := renderBundleBlocks(blocks[:len(blocks)-omitted], omitted)
+		if len(candidate) <= maxBytes {
+			return candidate, nil
+		}
+	}
+
+	return "", fmt.Errorf("render TypeScript bundle: maxBytes %d is too small for bundle prelude", maxBytes)
+}
+
+// renderBundleBlocks assembles the prelude, the tools declaration (empty
+// when no blocks remain), the helper declarations, and an optional
+// truncation note reporting how many ops were omitted.
+func renderBundleBlocks(blocks []string, omitted int) string {
+	var b strings.Builder
+	b.WriteString(graphQLErrorAlias)
+	b.WriteByte('\n')
+	b.WriteString(responseAlias)
+	b.WriteString("\n\n")
+
+	if len(blocks) == 0 {
+		b.WriteString("declare const tools: {};")
+	} else {
+		b.WriteString("declare const tools: {\n")
+		for i, block := range blocks {
+			if i > 0 {
+				b.WriteString("\n\n")
+			}
+			b.WriteString(indentBlock(block, "  "))
+		}
+		b.WriteString("\n};")
+	}
+
+	b.WriteString("\n\n")
+	b.WriteString(notNullHelper)
+	b.WriteByte('\n')
+	b.WriteString(compactHelper)
+
+	if omitted > 0 {
+		fmt.Fprintf(&b, "\n// truncated: %d ops omitted", omitted)
+	}
+
+	return b.String()
+}
diff --git a/router/internal/codemode/tsgen/tsgen_test.go b/router/internal/codemode/tsgen/tsgen_test.go
new file mode 100644
index 0000000000..9e2dc2e626
--- /dev/null
+++ b/router/internal/codemode/tsgen/tsgen_test.go
@@ -0,0 +1,411 @@
+package tsgen
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"github.com/wundergraph/cosmo/router/internal/codemode/storage"
+	"github.com/wundergraph/graphql-go-tools/v2/pkg/ast"
+	"github.com/wundergraph/graphql-go-tools/v2/pkg/astparser"
+	"github.com/wundergraph/graphql-go-tools/v2/pkg/asttransform"
+)
+
+const testSchemaSDL = `
+schema {
+  query: Query
+  mutation: Mutation
+}
+
+type Query { + health: String! + node(id: ID!): User + search(cursor: String): SearchConnection! + tagged(tags: [String!]!): [User!]! + byStatus(status: Status!): [User!]! + filterUsers(filter: UserFilter): [User!]! + viewer: User + animal: Animal + pet: Pet + pets: [Pet!]! + maybePet: Pet + maybePets: [Pet] + requiredPets: [Pet!]! + searchResult: SearchResult + outsider: Outsider +} + +type Mutation { + renameUser(id: ID!, name: String!): User! +} + +type User { + id: ID! + name: String! + friend: User + tags: [String!]! +} + +type SearchConnection { + nodes: [User]! + nextCursor: String +} + +interface Animal { + id: ID! +} + +type Cat implements Animal & Pet & Friendly { + id: ID! + name: String! + friendliness: Int! + companion: Animal +} + +type Dog implements Pet & Friendly { + id: ID! + bark: String! + friendliness: Int! +} + +type Mouse implements Pet { + id: ID! + squeak: Boolean! +} + +interface Pet { + id: ID! +} + +interface Friendly { + friendliness: Int! +} + +interface Unrelated { + unrelated: String! +} + +type Outsider implements Unrelated { + id: ID! + unrelated: String! +} + +union SearchResult = User | Cat + +enum Status { + OPEN + CLOSED +} + +input UserFilter { + status: Status + tags: [String!] + limit: Int! +} +` + +func testSchema(t *testing.T) *ast.Document { + t.Helper() + + doc, report := astparser.ParseGraphqlDocumentString(testSchemaSDL) + require.False(t, report.HasErrors(), report.Error()) + require.NoError(t, asttransform.MergeDefinitionWithBaseSchema(&doc)) + + return &doc +} + +func TestNewOpsFragmentSignatures(t *testing.T) { + schema := testSchema(t) + + tests := []struct { + name string + op storage.SessionOp + want string + }{ + { + name: "var-less query", + op: storage.SessionOp{ + Name: "health", + Body: `query Health { health }`, + Kind: storage.OperationKindQuery, + Description: "Checks router health.", + }, + want: "/** Checks router health. 
*/\nhealth(): R<{ health: string }>;", + }, + { + name: "required scalar var", + op: storage.SessionOp{ + Name: "getNode", + Body: `query GetNode($id: ID!) { node(id: $id) { id } }`, + Kind: storage.OperationKindQuery, + Description: "Fetches a node.", + }, + want: "/** Fetches a node. */\ngetNode(vars: { id: string }): R<{ node: { id: string } | null }>;", + }, + { + name: "optional nullable var", + op: storage.SessionOp{ + Name: "search", + Body: `query Search($cursor: String) { search(cursor: $cursor) { nextCursor } }`, + Kind: storage.OperationKindQuery, + Description: "Searches users.", + }, + want: "/** Searches users. */\nsearch(vars?: { cursor?: string | null }): R<{ search: { nextCursor: string | null } }>;", + }, + { + name: "list non-null var", + op: storage.SessionOp{ + Name: "tagged", + Body: `query Tagged($tags: [String!]!) { tagged(tags: $tags) { id } }`, + Kind: storage.OperationKindQuery, + Description: "Fetches users by tag.", + }, + want: "/** Fetches users by tag. */\ntagged(vars: { tags: string[] }): R<{ tagged: { id: string }[] }>;", + }, + { + name: "enum var", + op: storage.SessionOp{ + Name: "byStatus", + Body: `query ByStatus($status: Status!) { byStatus(status: $status) { id } }`, + Kind: storage.OperationKindQuery, + Description: "Fetches users by status.", + }, + want: "/** Fetches users by status. */\nbyStatus(vars: { status: \"OPEN\" | \"CLOSED\" }): R<{ byStatus: { id: string }[] }>;", + }, + { + name: "input object var", + op: storage.SessionOp{ + Name: "filterUsers", + Body: `query FilterUsers($filter: UserFilter) { filterUsers(filter: $filter) { id } }`, + Kind: storage.OperationKindQuery, + Description: "Filters users.", + }, + want: "/** Filters users. 
*/\nfilterUsers(vars?: { filter?: { status?: \"OPEN\" | \"CLOSED\" | null; tags?: string[] | null; limit: number } | null }): R<{ filterUsers: { id: string }[] }>;", + }, + { + name: "nested object", + op: storage.SessionOp{ + Name: "viewer", + Body: `query Viewer { viewer { id friend { name } } }`, + Kind: storage.OperationKindQuery, + Description: "Fetches viewer.", + }, + want: "/** Fetches viewer. */\nviewer(): R<{ viewer: { id: string; friend: { name: string } | null } | null }>;", + }, + { + name: "aliased field", + op: storage.SessionOp{ + Name: "viewerAlias", + Body: `query ViewerAlias { me: viewer { id } }`, + Kind: storage.OperationKindQuery, + Description: "Fetches viewer with alias.", + }, + want: "/** Fetches viewer with alias. */\nviewerAlias(): R<{ me: { id: string } | null }>;", + }, + { + name: "inline fragment", + op: storage.SessionOp{ + Name: "viewerFragment", + Body: `query ViewerFragment { viewer { id ... on User { name } } }`, + Kind: storage.OperationKindQuery, + Description: "Fetches viewer fields.", + }, + want: "/** Fetches viewer fields. */\nviewerFragment(): R<{ viewer: { id: string; name: string } | null }>;", + }, + { + name: "union or interface output", + op: storage.SessionOp{ + Name: "animal", + Body: `query Animal { animal { id } }`, + Kind: storage.OperationKindQuery, + Description: "Fetches animal.", + }, + want: "/** Fetches animal. */\nanimal(): R<{ animal: { id: string } | null }>;", + }, + { + name: "mutation kind", + op: storage.SessionOp{ + Name: "renameUser", + Body: `mutation RenameUser($id: ID!, $name: String!) { renameUser(id: $id, name: $name) { id name } }`, + Kind: storage.OperationKindMutation, + Description: "Renames a user.", + }, + want: "/** Renames a user. 
*/\nrenameUser(vars: { id: string; name: string }): R<{ renameUser: { id: string; name: string } }>;", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NewOpsFragment([]storage.SessionOp{tt.op}, schema) + require.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestNewOpsFragmentAbstractSelections(t *testing.T) { + schema := testSchema(t) + + tests := []struct { + name string + op storage.SessionOp + want string + wantErr string + }{ + { + name: "interface, only __typename", + op: storage.SessionOp{ + Name: "petKind", + Body: `query PetKind { pet { __typename } }`, + Kind: storage.OperationKindQuery, + Description: "Pet kind.", + }, + want: "/** Pet kind. */\npetKind(): R<{ pet: { __typename: \"Cat\" } | { __typename: \"Dog\" } | { __typename: \"Mouse\" } | null }>;", + }, + { + name: "interface, bare field + one concrete fragment", + op: storage.SessionOp{ + Name: "petWithCatName", + Body: `query PetWithCatName { pet { id ... on Cat { name } } }`, + Kind: storage.OperationKindQuery, + Description: "Pet with cat name.", + }, + want: "/** Pet with cat name. */\npetWithCatName(): R<{ pet: { id: string; name: string } | { id: string } | { id: string } | null }>;", + }, + { + name: "interface, fragment on the same interface", + op: storage.SessionOp{ + Name: "petSameInterface", + Body: `query PetSameInterface { pet { ... on Pet { id } } }`, + Kind: storage.OperationKindQuery, + Description: "Pet same interface.", + }, + want: "/** Pet same interface. */\npetSameInterface(): R<{ pet: { id: string } | null }>;", + }, + { + name: "interface, fragment on an unrelated abstract", + op: storage.SessionOp{ + Name: "petUnrelated", + Body: `query PetUnrelated { pet { id ... on Unrelated { unrelated } } }`, + Kind: storage.OperationKindQuery, + Description: "Pet unrelated.", + }, + want: "/** Pet unrelated. 
*/\npetUnrelated(): R<{ pet: { id: string } | null }>;", + }, + { + name: "interface, fragment on a related abstract", + op: storage.SessionOp{ + Name: "petFriendly", + Body: `query PetFriendly { pet { id ... on Friendly { friendliness } } }`, + Kind: storage.OperationKindQuery, + Description: "Pet friendly.", + }, + want: "/** Pet friendly. */\npetFriendly(): R<{ pet: { id: string; friendliness: number } | { id: string; friendliness: number } | { id: string } | null }>;", + }, + { + name: "concrete fragment on a non-implementor type", + op: storage.SessionOp{ + Name: "petBadFragment", + Body: `query PetBadFragment { pet { ... on User { name } } }`, + Kind: storage.OperationKindQuery, + Description: "Pet with non-implementor fragment.", + }, + wantErr: `render op "petBadFragment": type "User" is not a possible type of "Pet"`, + }, + { + name: "union, __typename-only selection", + op: storage.SessionOp{ + Name: "searchKind", + Body: `query SearchKind { searchResult { __typename } }`, + Kind: storage.OperationKindQuery, + Description: "Search kind.", + }, + want: "/** Search kind. */\nsearchKind(): R<{ searchResult: { __typename: \"User\" } | { __typename: \"Cat\" } | null }>;", + }, + { + name: "union with ... on Member for a subset", + op: storage.SessionOp{ + Name: "searchSubset", + Body: `query SearchSubset { searchResult { __typename ... on Cat { name } } }`, + Kind: storage.OperationKindQuery, + Description: "Search subset.", + }, + want: "/** Search subset. */\nsearchSubset(): R<{ searchResult: { __typename: \"User\" } | { __typename: \"Cat\"; name: string } | null }>;", + }, + { + name: "named fragment spread on abstract field", + op: storage.SessionOp{ + Name: "petSpread", + Body: `query PetSpread { pet { ...Bits } } fragment Bits on Pet { id }`, + Kind: storage.OperationKindQuery, + Description: "Pet spread.", + }, + want: "/** Pet spread. 
*/\npetSpread(): R<{ pet: { id: string } | null }>;", + }, + { + name: "aliased __typename", + op: storage.SessionOp{ + Name: "petAliasedKind", + Body: `query PetAliasedKind { pet { kind: __typename } }`, + Kind: storage.OperationKindQuery, + Description: "Pet aliased kind.", + }, + want: "/** Pet aliased kind. */\npetAliasedKind(): R<{ pet: { kind: string } | null }>;", + }, + { + name: "duplicate response keys, identical", + op: storage.SessionOp{ + Name: "petDupIdentical", + Body: `query PetDupIdentical { pet { id id } }`, + Kind: storage.OperationKindQuery, + Description: "Pet dup identical.", + }, + // merging is out of scope for this PR; pin duplicates as duplicates + want: "/** Pet dup identical. */\npetDupIdentical(): R<{ pet: { id: string; id: string } | null }>;", + }, + { + name: "duplicate response keys, conflicting", + op: storage.SessionOp{ + Name: "petDupConflict", + Body: `query PetDupConflict { pet { id ... on Cat { id: name } } }`, + Kind: storage.OperationKindQuery, + Description: "Pet dup conflict.", + }, + // merging is out of scope; conflicting duplicates are emitted as-is + // instead of erroring (mirrors current object-selection behavior). + want: "/** Pet dup conflict. */\npetDupConflict(): R<{ pet: { id: string; id: string } | { id: string } | { id: string } | null }>;", + }, + { + name: "nested abstract inside an inline fragment", + op: storage.SessionOp{ + Name: "petCompanion", + Body: `query PetCompanion { pet { ... on Cat { companion { __typename } } } }`, + Kind: storage.OperationKindQuery, + Description: "Pet companion.", + }, + want: "/** Pet companion. 
*/\npetCompanion(): R<{ pet: { companion: { __typename: \"Cat\" } | null } | null }>;", + }, + { + name: "list / nullable / non-nullable wrapping", + op: storage.SessionOp{ + Name: "petsWrappers", + Body: `query PetsWrappers { pets { id } maybePet { id } maybePets { id } requiredPets { id } }`, + Kind: storage.OperationKindQuery, + Description: "Pets wrappers.", + }, + want: "/** Pets wrappers. */\npetsWrappers(): R<{ pets: { id: string }[]; maybePet: { id: string } | null; maybePets: ({ id: string } | null)[] | null; requiredPets: { id: string }[] }>;", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NewOpsFragment([]storage.SessionOp{tt.op}, schema) + if tt.wantErr != "" { + require.EqualError(t, err, tt.wantErr) + return + } + require.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/router/internal/codemode/tsgen/typescript.go b/router/internal/codemode/tsgen/typescript.go new file mode 100644 index 0000000000..3807f6c7a7 --- /dev/null +++ b/router/internal/codemode/tsgen/typescript.go @@ -0,0 +1,102 @@ +package tsgen + +import ( + "strconv" + "strings" +) + +type tsProperty struct { + name string + typ string + optional bool +} + +func writeJSDoc(description string) string { + clean := strings.Join(strings.Fields(description), " ") + clean = strings.ReplaceAll(clean, "*/", "* /") + if clean == "" { + clean = "Registered GraphQL operation." 
+ } + return "/** " + clean + " */" +} + +func writeFieldSignature(description, name, varsType, outputType string, varsOptional bool) string { + var b strings.Builder + b.WriteString(writeJSDoc(description)) + b.WriteByte('\n') + b.WriteString(name) + if varsType == "{}" { + b.WriteString("()") + } else { + b.WriteString("(vars") + if varsOptional { + b.WriteByte('?') + } + b.WriteString(": ") + b.WriteString(varsType) + b.WriteByte(')') + } + b.WriteString(": R<") + b.WriteString(outputType) + b.WriteString(">;") + return b.String() +} + +func writeInlineObject(fields []tsProperty) string { + if len(fields) == 0 { + return "{}" + } + + parts := make([]string, 0, len(fields)) + for _, field := range fields { + suffix := ": " + if field.optional { + suffix = "?: " + } + parts = append(parts, field.name+suffix+field.typ) + } + + return "{ " + strings.Join(parts, "; ") + " }" +} + +func writeArray(item string) string { + if strings.Contains(item, " | ") { + item = "(" + item + ")" + } + return item + "[]" +} + +func writeNullable(typ string) string { + if strings.HasSuffix(typ, " | null") { + return typ + } + return typ + " | null" +} + +func writeStringLiteralUnion(values []string) string { + if len(values) == 0 { + return "unknown" + } + + quoted := make([]string, 0, len(values)) + for _, value := range values { + quoted = append(quoted, strconv.Quote(value)) + } + + return strings.Join(quoted, " | ") +} + +func indentBlock(block, indent string) string { + if block == "" { + return "" + } + + lines := strings.Split(block, "\n") + for i := range lines { + if lines[i] != "" { + lines[i] = indent + lines[i] + } + } + + return strings.Join(lines, "\n") +} diff --git a/router/internal/codemode/yoko/client.go b/router/internal/codemode/yoko/client.go new file mode 100644 index 0000000000..d89b09f7bd --- /dev/null +++ b/router/internal/codemode/yoko/client.go @@ -0,0 +1,158 @@ +package yoko + +import ( + "context" + "net/http" + "sync" + + "connectrpc.com/connect" + yokov1 
"github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/code_mode/yoko/v1" + yokoconnect "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/code_mode/yoko/v1/yokov1connect" + "go.uber.org/zap" + "golang.org/x/sync/singleflight" +) + +type Option func(*Client) + +func WithServiceClient(serviceClient yokoconnect.YokoServiceClient) Option { + return func(c *Client) { + if serviceClient != nil { + c.serviceClient = serviceClient + } + } +} + +type Client struct { + serviceClient yokoconnect.YokoServiceClient + logger *zap.Logger + + schemaMu sync.RWMutex + schemaSDL string + schemaID string + + indexGroup singleflight.Group +} + +func New(httpClient *http.Client, baseURL string, logger *zap.Logger, opts ...Option) *Client { + if httpClient == nil { + httpClient = http.DefaultClient + } + if logger == nil { + logger = zap.NewNop() + } + + client := &Client{ + serviceClient: yokoconnect.NewYokoServiceClient(httpClient, baseURL), + logger: logger, + } + for _, opt := range opts { + opt(client) + } + return client +} + +func (c *Client) Search(ctx context.Context, sessionID string, prompts []string) (*yokov1.SearchResponse, error) { + schemaID, err := c.ensureSchemaID(ctx) + if err != nil { + return nil, err + } + + resp, err := c.search(ctx, schemaID, sessionID, prompts) + if err == nil { + return resp, nil + } + if connect.CodeOf(err) != connect.CodeNotFound { + return nil, err + } + + c.invalidateSchemaID(schemaID) + + schemaID, err = c.ensureSchemaID(ctx) + if err != nil { + return nil, err + } + + resp, err = c.search(ctx, schemaID, sessionID, prompts) + if err != nil { + c.invalidateSchemaID(schemaID) + return nil, err + } + return resp, nil +} + +func (c *Client) SetSchema(sdl string) { + c.schemaMu.Lock() + defer c.schemaMu.Unlock() + c.schemaSDL = sdl + c.schemaID = "" +} + +func (c *Client) Schema() string { + c.schemaMu.RLock() + defer c.schemaMu.RUnlock() + return c.schemaSDL +} + +func (c *Client) ensureSchemaID(ctx context.Context) (string, error) { + 
sdl, schemaID := c.schemaState() + if schemaID != "" { + return schemaID, nil + } + + // Key by raw SDL because Yoko, not the router, owns schema identity. + value, err, _ := c.indexGroup.Do(sdl, func() (any, error) { + currentSDL, currentSchemaID := c.schemaState() + if currentSDL == sdl && currentSchemaID != "" { + return currentSchemaID, nil + } + + resp, err := c.serviceClient.Index(ctx, connect.NewRequest(&yokov1.IndexRequest{ + SchemaSdl: sdl, + })) + if err != nil { + return "", err + } + + indexedSchemaID := resp.Msg.GetSchemaId() + c.cacheSchemaID(currentSDL, indexedSchemaID) + return indexedSchemaID, nil + }) + if err != nil { + return "", err + } + return value.(string), nil +} + +func (c *Client) search(ctx context.Context, schemaID string, sessionID string, prompts []string) (*yokov1.SearchResponse, error) { + resp, err := c.serviceClient.Search(ctx, connect.NewRequest(&yokov1.SearchRequest{ + Prompts: prompts, + SchemaId: schemaID, + SessionId: sessionID, + })) + if err != nil { + return nil, err + } + return resp.Msg, nil +} + +func (c *Client) schemaState() (string, string) { + c.schemaMu.RLock() + defer c.schemaMu.RUnlock() + return c.schemaSDL, c.schemaID +} + +func (c *Client) cacheSchemaID(sdl string, schemaID string) { + c.schemaMu.Lock() + defer c.schemaMu.Unlock() + if c.schemaSDL == sdl { + c.schemaID = schemaID + } +} + +func (c *Client) invalidateSchemaID(schemaID string) { + c.schemaMu.Lock() + defer c.schemaMu.Unlock() + if c.schemaID == schemaID { + c.schemaID = "" + } +} diff --git a/router/internal/codemode/yoko/client_test.go b/router/internal/codemode/yoko/client_test.go new file mode 100644 index 0000000000..136e5193c8 --- /dev/null +++ b/router/internal/codemode/yoko/client_test.go @@ -0,0 +1,434 @@ +package yoko + +import ( + "context" + "errors" + "sync" + "testing" + "time" + + "connectrpc.com/connect" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + yokov1 
"github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/code_mode/yoko/v1" +) + +type fakeYokoServiceClient struct { + mu sync.Mutex + + indexRequests []*yokov1.IndexRequest + searchRequests []*yokov1.SearchRequest + + indexFunc func(context.Context, *connect.Request[yokov1.IndexRequest]) (*connect.Response[yokov1.IndexResponse], error) + searchFunc func(context.Context, *connect.Request[yokov1.SearchRequest]) (*connect.Response[yokov1.SearchResponse], error) +} + +func (f *fakeYokoServiceClient) Index(ctx context.Context, req *connect.Request[yokov1.IndexRequest]) (*connect.Response[yokov1.IndexResponse], error) { + f.mu.Lock() + f.indexRequests = append(f.indexRequests, req.Msg) + indexFunc := f.indexFunc + f.mu.Unlock() + + if indexFunc != nil { + return indexFunc(ctx, req) + } + return connect.NewResponse(&yokov1.IndexResponse{SchemaId: "schema-1"}), nil +} + +func (f *fakeYokoServiceClient) Search(ctx context.Context, req *connect.Request[yokov1.SearchRequest]) (*connect.Response[yokov1.SearchResponse], error) { + f.mu.Lock() + f.searchRequests = append(f.searchRequests, req.Msg) + searchFunc := f.searchFunc + f.mu.Unlock() + + if searchFunc != nil { + return searchFunc(ctx, req) + } + return connect.NewResponse(searchResponse("op")), nil +} + +func (f *fakeYokoServiceClient) indexRequestMessages() []*yokov1.IndexRequest { + f.mu.Lock() + defer f.mu.Unlock() + return append([]*yokov1.IndexRequest(nil), f.indexRequests...) +} + +func (f *fakeYokoServiceClient) searchRequestMessages() []*yokov1.SearchRequest { + f.mu.Lock() + defer f.mu.Unlock() + return append([]*yokov1.SearchRequest(nil), f.searchRequests...) 
+} + +func newTestClient(fake *fakeYokoServiceClient) *Client { + client := New(nil, "http://yoko.example", nil, WithServiceClient(fake)) + client.SetSchema("type Query { product: Product }") + return client +} + +func searchResponse(name string) *yokov1.SearchResponse { + return &yokov1.SearchResponse{ + Operations: []*yokov1.GeneratedOperation{ + { + Name: name, + Body: "query " + name + " { product { id } }", + Kind: yokov1.OperationKind_OPERATION_KIND_QUERY, + Description: "Fetch product", + }, + }, + } +} + +func connectError(code connect.Code, message string) error { + return connect.NewError(code, errors.New(message)) +} + +func TestSearchFirstCallIndexesSchemaThenSearchesWithReturnedID(t *testing.T) { + fake := &fakeYokoServiceClient{ + indexFunc: func(context.Context, *connect.Request[yokov1.IndexRequest]) (*connect.Response[yokov1.IndexResponse], error) { + return connect.NewResponse(&yokov1.IndexResponse{SchemaId: "schema-from-yoko"}), nil + }, + searchFunc: func(context.Context, *connect.Request[yokov1.SearchRequest]) (*connect.Response[yokov1.SearchResponse], error) { + return connect.NewResponse(searchResponse("fromSearch")), nil + }, + } + client := newTestClient(fake) + + actual, err := client.Search(context.Background(), "session-1", []string{"find products"}) + + require.NoError(t, err) + require.Equal(t, searchResponse("fromSearch"), actual) + require.Equal(t, []*yokov1.IndexRequest{ + {SchemaSdl: "type Query { product: Product }"}, + }, fake.indexRequestMessages()) + require.Equal(t, []*yokov1.SearchRequest{ + { + Prompts: []string{"find products"}, + SchemaId: "schema-from-yoko", + SessionId: "session-1", + }, + }, fake.searchRequestMessages()) +} + +func TestSearchSubsequentCallUsesCachedSchemaID(t *testing.T) { + fake := &fakeYokoServiceClient{} + client := newTestClient(fake) + + first, firstErr := client.Search(context.Background(), "session-1", []string{"first"}) + second, secondErr := client.Search(context.Background(), "session-2", 
[]string{"second"}) + + require.NoError(t, firstErr) + require.NoError(t, secondErr) + require.Equal(t, searchResponse("op"), first) + require.Equal(t, searchResponse("op"), second) + require.Equal(t, []*yokov1.IndexRequest{ + {SchemaSdl: "type Query { product: Product }"}, + }, fake.indexRequestMessages()) + require.Equal(t, []*yokov1.SearchRequest{ + { + Prompts: []string{"first"}, + SchemaId: "schema-1", + SessionId: "session-1", + }, + { + Prompts: []string{"second"}, + SchemaId: "schema-1", + SessionId: "session-2", + }, + }, fake.searchRequestMessages()) +} + +func TestSearchReindexesAndRetriesOnceAfterNotFound(t *testing.T) { + var searchCount int + fake := &fakeYokoServiceClient{} + indexIDs := []string{"schema-initial", "schema-reindexed"} + fake.indexFunc = func(context.Context, *connect.Request[yokov1.IndexRequest]) (*connect.Response[yokov1.IndexResponse], error) { + id := indexIDs[len(fake.indexRequestMessages())-1] + return connect.NewResponse(&yokov1.IndexResponse{SchemaId: id}), nil + } + fake.searchFunc = func(context.Context, *connect.Request[yokov1.SearchRequest]) (*connect.Response[yokov1.SearchResponse], error) { + searchCount++ + if searchCount == 1 { + return nil, connectError(connect.CodeNotFound, "schema evicted") + } + return connect.NewResponse(searchResponse("retried")), nil + } + client := newTestClient(fake) + + actual, err := client.Search(context.Background(), "session-1", []string{"find products"}) + + require.NoError(t, err) + require.Equal(t, searchResponse("retried"), actual) + require.Equal(t, []*yokov1.IndexRequest{ + {SchemaSdl: "type Query { product: Product }"}, + {SchemaSdl: "type Query { product: Product }"}, + }, fake.indexRequestMessages()) + require.Equal(t, []*yokov1.SearchRequest{ + { + Prompts: []string{"find products"}, + SchemaId: "schema-initial", + SessionId: "session-1", + }, + { + Prompts: []string{"find products"}, + SchemaId: "schema-reindexed", + SessionId: "session-1", + }, + }, 
fake.searchRequestMessages()) +} + +func TestSearchRetryFailureSurfacesErrorAndLeavesCacheEmpty(t *testing.T) { + retryErr := connectError(connect.CodeUnavailable, "retry transport down") + indexIDs := []string{"schema-initial", "schema-reindexed", "schema-after-failure"} + fake := &fakeYokoServiceClient{} + fake.indexFunc = func(context.Context, *connect.Request[yokov1.IndexRequest]) (*connect.Response[yokov1.IndexResponse], error) { + id := indexIDs[len(fake.indexRequestMessages())-1] + return connect.NewResponse(&yokov1.IndexResponse{SchemaId: id}), nil + } + searchErrors := []error{ + connectError(connect.CodeNotFound, "schema evicted"), + retryErr, + nil, + } + fake.searchFunc = func(context.Context, *connect.Request[yokov1.SearchRequest]) (*connect.Response[yokov1.SearchResponse], error) { + err := searchErrors[len(fake.searchRequestMessages())-1] + if err != nil { + return nil, err + } + return connect.NewResponse(searchResponse("afterFailure")), nil + } + client := newTestClient(fake) + + actual, err := client.Search(context.Background(), "session-1", []string{"find products"}) + + require.Nil(t, actual) + require.ErrorIs(t, err, retryErr) + + actualAfterFailure, errAfterFailure := client.Search(context.Background(), "session-2", []string{"find products again"}) + + require.NoError(t, errAfterFailure) + require.Equal(t, searchResponse("afterFailure"), actualAfterFailure) + require.Equal(t, []*yokov1.IndexRequest{ + {SchemaSdl: "type Query { product: Product }"}, + {SchemaSdl: "type Query { product: Product }"}, + {SchemaSdl: "type Query { product: Product }"}, + }, fake.indexRequestMessages()) + require.Equal(t, []*yokov1.SearchRequest{ + { + Prompts: []string{"find products"}, + SchemaId: "schema-initial", + SessionId: "session-1", + }, + { + Prompts: []string{"find products"}, + SchemaId: "schema-reindexed", + SessionId: "session-1", + }, + { + Prompts: []string{"find products again"}, + SchemaId: "schema-after-failure", + SessionId: "session-2", + }, + 
}, fake.searchRequestMessages()) +} + +func TestSearchRetryNotFoundSurfacesErrorAndLeavesCacheEmpty(t *testing.T) { + retryErr := connectError(connect.CodeNotFound, "schema evicted again") + indexIDs := []string{"schema-initial", "schema-reindexed", "schema-after-failure"} + fake := &fakeYokoServiceClient{} + fake.indexFunc = func(context.Context, *connect.Request[yokov1.IndexRequest]) (*connect.Response[yokov1.IndexResponse], error) { + id := indexIDs[len(fake.indexRequestMessages())-1] + return connect.NewResponse(&yokov1.IndexResponse{SchemaId: id}), nil + } + searchErrors := []error{ + connectError(connect.CodeNotFound, "schema evicted"), + retryErr, + nil, + } + fake.searchFunc = func(context.Context, *connect.Request[yokov1.SearchRequest]) (*connect.Response[yokov1.SearchResponse], error) { + err := searchErrors[len(fake.searchRequestMessages())-1] + if err != nil { + return nil, err + } + return connect.NewResponse(searchResponse("afterFailure")), nil + } + client := newTestClient(fake) + + actual, err := client.Search(context.Background(), "session-1", []string{"find products"}) + + require.Nil(t, actual) + require.ErrorIs(t, err, retryErr) + + actualAfterFailure, errAfterFailure := client.Search(context.Background(), "session-2", []string{"find products again"}) + + require.NoError(t, errAfterFailure) + require.Equal(t, searchResponse("afterFailure"), actualAfterFailure) + require.Equal(t, []*yokov1.IndexRequest{ + {SchemaSdl: "type Query { product: Product }"}, + {SchemaSdl: "type Query { product: Product }"}, + {SchemaSdl: "type Query { product: Product }"}, + }, fake.indexRequestMessages()) +} + +func TestSetSchemaInvalidatesCachedIDAndNextSearchReindexes(t *testing.T) { + indexIDs := []string{"schema-v1", "schema-v2"} + fake := &fakeYokoServiceClient{} + fake.indexFunc = func(context.Context, *connect.Request[yokov1.IndexRequest]) (*connect.Response[yokov1.IndexResponse], error) { + id := indexIDs[len(fake.indexRequestMessages())-1] + return 
connect.NewResponse(&yokov1.IndexResponse{SchemaId: id}), nil + } + client := newTestClient(fake) + + _, firstErr := client.Search(context.Background(), "session-1", []string{"first"}) + client.SetSchema("type Query { review: Review }") + _, secondErr := client.Search(context.Background(), "session-2", []string{"second"}) + + require.NoError(t, firstErr) + require.NoError(t, secondErr) + require.Equal(t, "type Query { review: Review }", client.Schema()) + require.Equal(t, []*yokov1.IndexRequest{ + {SchemaSdl: "type Query { product: Product }"}, + {SchemaSdl: "type Query { review: Review }"}, + }, fake.indexRequestMessages()) + require.Equal(t, []*yokov1.SearchRequest{ + { + Prompts: []string{"first"}, + SchemaId: "schema-v1", + SessionId: "session-1", + }, + { + Prompts: []string{"second"}, + SchemaId: "schema-v2", + SessionId: "session-2", + }, + }, fake.searchRequestMessages()) +} + +func TestConcurrentFirstSearchIndexesOnce(t *testing.T) { + indexStarted := make(chan struct{}) + releaseIndex := make(chan struct{}) + var indexStartedOnce sync.Once + fake := &fakeYokoServiceClient{ + indexFunc: func(context.Context, *connect.Request[yokov1.IndexRequest]) (*connect.Response[yokov1.IndexResponse], error) { + indexStartedOnce.Do(func() { + close(indexStarted) + }) + <-releaseIndex + return connect.NewResponse(&yokov1.IndexResponse{SchemaId: "schema-shared"}), nil + }, + } + client := newTestClient(fake) + + var wg sync.WaitGroup + wg.Add(2) + results := make([]*yokov1.SearchResponse, 2) + errs := make([]error, 2) + go func() { + defer wg.Done() + results[0], errs[0] = client.Search(context.Background(), "session-1", []string{"first"}) + }() + <-indexStarted + go func() { + defer wg.Done() + results[1], errs[1] = client.Search(context.Background(), "session-2", []string{"second"}) + }() + time.Sleep(25 * time.Millisecond) + close(releaseIndex) + wg.Wait() + + require.NoError(t, errs[0]) + require.NoError(t, errs[1]) + require.Equal(t, searchResponse("op"), results[0]) 
+ require.Equal(t, searchResponse("op"), results[1]) + require.Equal(t, []*yokov1.IndexRequest{ + {SchemaSdl: "type Query { product: Product }"}, + }, fake.indexRequestMessages()) + assert.Equal(t, 2, len(fake.searchRequestMessages())) +} + +func TestConcurrentFirstSearchIndexFailureReturnsErrorToBothAndLeavesCacheEmpty(t *testing.T) { + indexErr := connectError(connect.CodeUnavailable, "index unavailable") + indexStarted := make(chan struct{}) + releaseIndex := make(chan struct{}) + var indexStartedOnce sync.Once + fake := &fakeYokoServiceClient{ + indexFunc: func(context.Context, *connect.Request[yokov1.IndexRequest]) (*connect.Response[yokov1.IndexResponse], error) { + indexStartedOnce.Do(func() { + close(indexStarted) + }) + <-releaseIndex + return nil, indexErr + }, + } + client := newTestClient(fake) + + var wg sync.WaitGroup + wg.Add(2) + results := make([]*yokov1.SearchResponse, 2) + errs := make([]error, 2) + go func() { + defer wg.Done() + results[0], errs[0] = client.Search(context.Background(), "session-1", []string{"first"}) + }() + <-indexStarted + go func() { + defer wg.Done() + results[1], errs[1] = client.Search(context.Background(), "session-2", []string{"second"}) + }() + time.Sleep(25 * time.Millisecond) + close(releaseIndex) + wg.Wait() + + require.Nil(t, results[0]) + require.Nil(t, results[1]) + require.ErrorIs(t, errs[0], indexErr) + require.ErrorIs(t, errs[1], indexErr) + require.Equal(t, []*yokov1.IndexRequest{ + {SchemaSdl: "type Query { product: Product }"}, + }, fake.indexRequestMessages()) + require.Equal(t, []*yokov1.SearchRequest(nil), fake.searchRequestMessages()) + + fake.indexFunc = func(context.Context, *connect.Request[yokov1.IndexRequest]) (*connect.Response[yokov1.IndexResponse], error) { + return connect.NewResponse(&yokov1.IndexResponse{SchemaId: "schema-after-error"}), nil + } + actual, err := client.Search(context.Background(), "session-3", []string{"third"}) + + require.NoError(t, err) + require.Equal(t, 
searchResponse("op"), actual) + require.Equal(t, []*yokov1.IndexRequest{ + {SchemaSdl: "type Query { product: Product }"}, + {SchemaSdl: "type Query { product: Product }"}, + }, fake.indexRequestMessages()) +} + +func TestSearchBubblesUpArbitraryConnectErrors(t *testing.T) { + searchErr := connectError(connect.CodeUnavailable, "search unavailable") + fake := &fakeYokoServiceClient{ + searchFunc: func(context.Context, *connect.Request[yokov1.SearchRequest]) (*connect.Response[yokov1.SearchResponse], error) { + return nil, searchErr + }, + } + client := newTestClient(fake) + + actual, err := client.Search(context.Background(), "session-1", []string{"find products"}) + + require.Nil(t, actual) + require.ErrorIs(t, err, searchErr) + require.Equal(t, []*yokov1.IndexRequest{ + {SchemaSdl: "type Query { product: Product }"}, + }, fake.indexRequestMessages()) + require.Equal(t, []*yokov1.SearchRequest{ + { + Prompts: []string{"find products"}, + SchemaId: "schema-1", + SessionId: "session-1", + }, + }, fake.searchRequestMessages()) +} + +func TestSchemaGetterReturnsCurrentSchema(t *testing.T) { + client := New(nil, "http://yoko.example", nil, WithServiceClient(&fakeYokoServiceClient{})) + + require.Equal(t, "", client.Schema()) + client.SetSchema("type Query { store: Store }") + require.Equal(t, "type Query { store: Store }", client.Schema()) +} diff --git a/router/internal/codemode/yoko/searcher.go b/router/internal/codemode/yoko/searcher.go new file mode 100644 index 0000000000..611e5f6fe2 --- /dev/null +++ b/router/internal/codemode/yoko/searcher.go @@ -0,0 +1,15 @@ +package yoko + +import ( + "context" + + yokov1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/code_mode/yoko/v1" +) + +type Searcher interface { + Search(ctx context.Context, sessionID string, prompts []string) (*yokov1.SearchResponse, error) + SetSchema(string) + Schema() string +} + +var _ Searcher = (*Client)(nil) diff --git a/router/pkg/config/code_mode_config_test.go 
b/router/pkg/config/code_mode_config_test.go new file mode 100644 index 0000000000..839c0ab403 --- /dev/null +++ b/router/pkg/config/code_mode_config_test.go @@ -0,0 +1,278 @@ +package config + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMCPCodeModeConfigurationDefaults(t *testing.T) { + f := createTempFileFromFixture(t, ` +version: "1" +`) + + cfg, err := LoadConfig([]string{f}) + require.NoError(t, err) + + assert.Equal(t, MCPCodeModeConfiguration{ + Enabled: false, + Server: MCPCodeModeServerConfig{ListenAddr: "localhost:5027"}, + RequireMutationApproval: true, + ExecuteTimeout: 120 * time.Second, + MaxResultBytes: 32768, + Sandbox: MCPCodeModeSandboxConfig{ + Timeout: 5 * time.Second, + MaxMemoryMB: 16, + MaxInputSizeBytes: 65536, + MaxOutputSizeBytes: 1048576, + }, + QueryGeneration: MCPCodeModeQueryGenConfig{ + Enabled: false, + Endpoint: "", + Timeout: 10 * time.Second, + Auth: MCPCodeModeQueryGenAuthConfig{ + Type: "static", + StaticToken: "", + TokenEndpoint: "", + ClientID: "", + ClientSecret: "", + }, + }, + NamedOps: MCPCodeModeNamedOpsConfig{ + Enabled: false, + SessionTTL: 30 * time.Minute, + MaxSessions: 1000, + MaxBundleBytes: 262144, + Storage: MCPCodeModeNamedOpsStorageConfig{ + ProviderID: "", + KeyPrefix: "cosmo_code_mode", + }, + }, + }, cfg.Config.MCP.CodeMode) +} + +func TestMCPCodeModeConfigurationFullYAMLOverride(t *testing.T) { + f := createTempFileFromFixture(t, ` +version: "1" + +mcp: + session: + stateless: false + code_mode: + enabled: true + server: + listen_addr: "0.0.0.0:6027" + require_mutation_approval: false + execute_timeout: "45s" + max_result_bytes: 64000 + sandbox: + timeout: "7s" + max_memory_mb: 32 + max_input_size_bytes: 131072 + max_output_size_bytes: 2097152 + query_generation: + enabled: true + endpoint: "https://yoko.example.com" + timeout: "15s" + auth: + type: "jwt" + static_token: "unused-static" + token_endpoint: 
"https://auth.example.com/token" + client_id: "router-client" + client_secret: "router-secret" + named_ops: + enabled: true + session_ttl: "45m" + max_sessions: 2000 + max_bundle_bytes: 524288 + storage: + provider_id: "my_redis" + key_prefix: "custom_code_mode" +`) + + cfg, err := LoadConfig([]string{f}) + require.NoError(t, err) + + assert.Equal(t, MCPCodeModeConfiguration{ + Enabled: true, + Server: MCPCodeModeServerConfig{ListenAddr: "0.0.0.0:6027"}, + RequireMutationApproval: false, + ExecuteTimeout: 45 * time.Second, + MaxResultBytes: 64000, + Sandbox: MCPCodeModeSandboxConfig{ + Timeout: 7 * time.Second, + MaxMemoryMB: 32, + MaxInputSizeBytes: 131072, + MaxOutputSizeBytes: 2097152, + }, + QueryGeneration: MCPCodeModeQueryGenConfig{ + Enabled: true, + Endpoint: "https://yoko.example.com", + Timeout: 15 * time.Second, + Auth: MCPCodeModeQueryGenAuthConfig{ + Type: "jwt", + StaticToken: "unused-static", + TokenEndpoint: "https://auth.example.com/token", + ClientID: "router-client", + ClientSecret: "router-secret", + }, + }, + NamedOps: MCPCodeModeNamedOpsConfig{ + Enabled: true, + SessionTTL: 45 * time.Minute, + MaxSessions: 2000, + MaxBundleBytes: 524288, + Storage: MCPCodeModeNamedOpsStorageConfig{ + ProviderID: "my_redis", + KeyPrefix: "custom_code_mode", + }, + }, + }, cfg.Config.MCP.CodeMode) +} + +func TestMCPCodeModeConfigurationEnvOverride(t *testing.T) { + t.Setenv("MCP_CODE_MODE_ENABLED", "true") + t.Setenv("MCP_CODE_MODE_LISTEN_ADDR", "127.0.0.1:6027") + t.Setenv("MCP_CODE_MODE_REQUIRE_MUTATION_APPROVAL", "false") + t.Setenv("MCP_CODE_MODE_EXECUTE_TIMEOUT", "30s") + t.Setenv("MCP_CODE_MODE_MAX_RESULT_BYTES", "49152") + t.Setenv("MCP_CODE_MODE_SANDBOX_TIMEOUT", "8s") + t.Setenv("MCP_CODE_MODE_SANDBOX_MAX_MEMORY_MB", "64") + t.Setenv("MCP_CODE_MODE_SANDBOX_MAX_INPUT_SIZE_BYTES", "262144") + t.Setenv("MCP_CODE_MODE_SANDBOX_MAX_OUTPUT_SIZE_BYTES", "3145728") + t.Setenv("MCP_CODE_MODE_QUERY_GENERATION_ENABLED", "true") + 
t.Setenv("MCP_CODE_MODE_QUERY_GENERATION_ENDPOINT", "https://env-yoko.example.com") + t.Setenv("MCP_CODE_MODE_QUERY_GENERATION_TIMEOUT", "20s") + t.Setenv("MCP_CODE_MODE_QUERY_GENERATION_AUTH_TYPE", "jwt") + t.Setenv("MCP_CODE_MODE_QUERY_GENERATION_AUTH_STATIC_TOKEN", "env-static-token") + t.Setenv("MCP_CODE_MODE_QUERY_GENERATION_AUTH_TOKEN_ENDPOINT", "https://env-auth.example.com/token") + t.Setenv("MCP_CODE_MODE_QUERY_GENERATION_AUTH_CLIENT_ID", "env-client") + t.Setenv("MCP_CODE_MODE_QUERY_GENERATION_AUTH_CLIENT_SECRET", "env-secret") + t.Setenv("MCP_CODE_MODE_NAMED_OPS_ENABLED", "true") + t.Setenv("MCP_CODE_MODE_NAMED_OPS_SESSION_TTL", "1h") + t.Setenv("MCP_CODE_MODE_NAMED_OPS_MAX_SESSIONS", "3000") + t.Setenv("MCP_CODE_MODE_NAMED_OPS_MAX_BUNDLE_BYTES", "1048576") + t.Setenv("MCP_CODE_MODE_NAMED_OPS_STORAGE_PROVIDER_ID", "env_redis") + t.Setenv("MCP_CODE_MODE_NAMED_OPS_STORAGE_KEY_PREFIX", "env_code_mode") + + f := createTempFileFromFixture(t, ` +version: "1" + +mcp: + session: + stateless: false +`) + + cfg, err := LoadConfig([]string{f}) + require.NoError(t, err) + + assert.Equal(t, MCPCodeModeConfiguration{ + Enabled: true, + Server: MCPCodeModeServerConfig{ListenAddr: "127.0.0.1:6027"}, + RequireMutationApproval: false, + ExecuteTimeout: 30 * time.Second, + MaxResultBytes: 49152, + Sandbox: MCPCodeModeSandboxConfig{ + Timeout: 8 * time.Second, + MaxMemoryMB: 64, + MaxInputSizeBytes: 262144, + MaxOutputSizeBytes: 3145728, + }, + QueryGeneration: MCPCodeModeQueryGenConfig{ + Enabled: true, + Endpoint: "https://env-yoko.example.com", + Timeout: 20 * time.Second, + Auth: MCPCodeModeQueryGenAuthConfig{ + Type: "jwt", + StaticToken: "env-static-token", + TokenEndpoint: "https://env-auth.example.com/token", + ClientID: "env-client", + ClientSecret: "env-secret", + }, + }, + NamedOps: MCPCodeModeNamedOpsConfig{ + Enabled: true, + SessionTTL: time.Hour, + MaxSessions: 3000, + MaxBundleBytes: 1048576, + Storage: MCPCodeModeNamedOpsStorageConfig{ + ProviderID: 
"env_redis", + KeyPrefix: "env_code_mode", + }, + }, + }, cfg.Config.MCP.CodeMode) +} + +func TestValidateMCPCodeMode(t *testing.T) { + tests := []struct { + name string + cfg MCPCodeModeConfiguration + sessionStateless bool + wantErr string + }{ + { + name: "code mode disabled skips validation", + cfg: MCPCodeModeConfiguration{ + Enabled: false, + NamedOps: MCPCodeModeNamedOpsConfig{ + Enabled: true, + }, + }, + }, + { + name: "named ops disabled skips validation", + cfg: MCPCodeModeConfiguration{ + Enabled: true, + NamedOps: MCPCodeModeNamedOpsConfig{ + Enabled: false, + }, + }, + }, + { + name: "memory backend (no provider_id) is valid", + cfg: MCPCodeModeConfiguration{ + Enabled: true, + NamedOps: MCPCodeModeNamedOpsConfig{ + Enabled: true, + Storage: MCPCodeModeNamedOpsStorageConfig{KeyPrefix: "cosmo_code_mode"}, + }, + }, + }, + { + name: "redis-backed (provider_id set) is valid", + cfg: MCPCodeModeConfiguration{ + Enabled: true, + NamedOps: MCPCodeModeNamedOpsConfig{ + Enabled: true, + Storage: MCPCodeModeNamedOpsStorageConfig{ + ProviderID: "my_redis", + KeyPrefix: "cosmo_code_mode", + }, + }, + }, + }, + { + name: "stateless named ops does not fail boot validation", + cfg: MCPCodeModeConfiguration{ + Enabled: true, + NamedOps: MCPCodeModeNamedOpsConfig{ + Enabled: true, + }, + }, + sessionStateless: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateMCPCodeMode(&tt.cfg, tt.sessionStateless) + if tt.wantErr == "" { + require.NoError(t, err) + return + } + require.EqualError(t, err, tt.wantErr) + }) + } +} diff --git a/router/pkg/config/code_mode_validation.go b/router/pkg/config/code_mode_validation.go new file mode 100644 index 0000000000..5039ec9a4e --- /dev/null +++ b/router/pkg/config/code_mode_validation.go @@ -0,0 +1,23 @@ +package config + +func ValidateMCPCodeMode(cfg *MCPCodeModeConfiguration, sessionStateless bool) error { + if !cfg.Enabled { + return nil + } + + if !cfg.NamedOps.Enabled { + return 
nil + } + + // Storage backend selection: when ProviderID is set, the router resolves it + // against the central storage_providers registry (Redis backend). Otherwise + // the in-memory backend is used. The provider lookup error (unknown id) is + // emitted by the router at startup, not here. + + // Named ops require stateful MCP sessions to work, but this intentionally + // does not fail boot. The Code Mode runtime emits the warn log on first + // reload so deployments can enable Code Mode before flipping session mode. + _ = sessionStateless + + return nil +} diff --git a/router/pkg/config/config.go b/router/pkg/config/config.go index 22f58cf72a..fb709303ea 100644 --- a/router/pkg/config/config.go +++ b/router/pkg/config/config.go @@ -1142,15 +1142,16 @@ type CacheWarmupConfiguration struct { } type MCPConfiguration struct { - Enabled bool `yaml:"enabled" envDefault:"false" env:"MCP_ENABLED"` - Server MCPServer `yaml:"server,omitempty"` - Storage MCPStorageConfig `yaml:"storage,omitempty"` - Session MCPSessionConfig `yaml:"session,omitempty"` - GraphName string `yaml:"graph_name" envDefault:"mygraph" env:"MCP_GRAPH_NAME"` - ExcludeMutations bool `yaml:"exclude_mutations" envDefault:"false" env:"MCP_EXCLUDE_MUTATIONS"` - EnableArbitraryOperations bool `yaml:"enable_arbitrary_operations" envDefault:"false" env:"MCP_ENABLE_ARBITRARY_OPERATIONS"` - ExposeSchema bool `yaml:"expose_schema" envDefault:"false" env:"MCP_EXPOSE_SCHEMA"` - RouterURL string `yaml:"router_url,omitempty" env:"MCP_ROUTER_URL"` + Enabled bool `yaml:"enabled" envDefault:"false" env:"MCP_ENABLED"` + Server MCPServer `yaml:"server,omitempty"` + Storage MCPStorageConfig `yaml:"storage,omitempty"` + Session MCPSessionConfig `yaml:"session,omitempty"` + CodeMode MCPCodeModeConfiguration `yaml:"code_mode,omitempty" envPrefix:"MCP_CODE_MODE_"` + GraphName string `yaml:"graph_name" envDefault:"mygraph" env:"MCP_GRAPH_NAME"` + ExcludeMutations bool `yaml:"exclude_mutations" envDefault:"false" 
env:"MCP_EXCLUDE_MUTATIONS"` + EnableArbitraryOperations bool `yaml:"enable_arbitrary_operations" envDefault:"false" env:"MCP_ENABLE_ARBITRARY_OPERATIONS"` + ExposeSchema bool `yaml:"expose_schema" envDefault:"false" env:"MCP_EXPOSE_SCHEMA"` + RouterURL string `yaml:"router_url,omitempty" env:"MCP_ROUTER_URL"` // OmitToolNamePrefix removes the "execute_operation_" prefix from MCP tool names. // When enabled, GetUser becomes get_user. When disabled (default), GetUser becomes execute_operation_get_user. OmitToolNamePrefix bool `yaml:"omit_tool_name_prefix" envDefault:"false" env:"MCP_OMIT_TOOL_NAME_PREFIX"` @@ -1203,6 +1204,56 @@ type MCPSessionConfig struct { Stateless bool `yaml:"stateless" envDefault:"true" env:"MCP_SESSION_STATELESS"` } +type MCPCodeModeConfiguration struct { + Enabled bool `yaml:"enabled" envDefault:"false" env:"ENABLED"` + Server MCPCodeModeServerConfig `yaml:"server,omitempty" envPrefix:""` + RequireMutationApproval bool `yaml:"require_mutation_approval" envDefault:"true" env:"REQUIRE_MUTATION_APPROVAL"` + ExecuteTimeout time.Duration `yaml:"execute_timeout" envDefault:"120s" env:"EXECUTE_TIMEOUT"` + MaxResultBytes int `yaml:"max_result_bytes" envDefault:"32768" env:"MAX_RESULT_BYTES"` + Sandbox MCPCodeModeSandboxConfig `yaml:"sandbox,omitempty" envPrefix:"SANDBOX_"` + QueryGeneration MCPCodeModeQueryGenConfig `yaml:"query_generation,omitempty" envPrefix:"QUERY_GENERATION_"` + NamedOps MCPCodeModeNamedOpsConfig `yaml:"named_ops,omitempty" envPrefix:"NAMED_OPS_"` +} + +type MCPCodeModeServerConfig struct { + ListenAddr string `yaml:"listen_addr" envDefault:"localhost:5027" env:"LISTEN_ADDR"` +} + +type MCPCodeModeSandboxConfig struct { + Timeout time.Duration `yaml:"timeout" envDefault:"5s" env:"TIMEOUT"` + MaxMemoryMB int `yaml:"max_memory_mb" envDefault:"16" env:"MAX_MEMORY_MB"` + MaxInputSizeBytes int `yaml:"max_input_size_bytes" envDefault:"65536" env:"MAX_INPUT_SIZE_BYTES"` + MaxOutputSizeBytes int `yaml:"max_output_size_bytes" 
envDefault:"1048576" env:"MAX_OUTPUT_SIZE_BYTES"` +} + +type MCPCodeModeQueryGenConfig struct { + Enabled bool `yaml:"enabled" envDefault:"false" env:"ENABLED"` + Endpoint string `yaml:"endpoint,omitempty" env:"ENDPOINT"` + Timeout time.Duration `yaml:"timeout" envDefault:"10s" env:"TIMEOUT"` + Auth MCPCodeModeQueryGenAuthConfig `yaml:"auth,omitempty" envPrefix:"AUTH_"` +} + +type MCPCodeModeQueryGenAuthConfig struct { + Type string `yaml:"type" envDefault:"static" env:"TYPE"` + StaticToken string `yaml:"static_token,omitempty" env:"STATIC_TOKEN"` + TokenEndpoint string `yaml:"token_endpoint,omitempty" env:"TOKEN_ENDPOINT"` + ClientID string `yaml:"client_id,omitempty" env:"CLIENT_ID"` + ClientSecret string `yaml:"client_secret,omitempty" env:"CLIENT_SECRET"` +} + +type MCPCodeModeNamedOpsConfig struct { + Enabled bool `yaml:"enabled" envDefault:"false" env:"ENABLED"` + SessionTTL time.Duration `yaml:"session_ttl" envDefault:"30m" env:"SESSION_TTL"` + MaxSessions int `yaml:"max_sessions" envDefault:"1000" env:"MAX_SESSIONS"` + MaxBundleBytes int `yaml:"max_bundle_bytes" envDefault:"262144" env:"MAX_BUNDLE_BYTES"` + Storage MCPCodeModeNamedOpsStorageConfig `yaml:"storage,omitempty" envPrefix:"STORAGE_"` +} + +type MCPCodeModeNamedOpsStorageConfig struct { + ProviderID string `yaml:"provider_id,omitempty" env:"PROVIDER_ID"` + KeyPrefix string `yaml:"key_prefix" envDefault:"cosmo_code_mode" env:"KEY_PREFIX"` +} + type MCPStorageConfig struct { ProviderID string `yaml:"provider_id,omitempty" env:"MCP_STORAGE_PROVIDER_ID"` } @@ -1462,5 +1513,9 @@ func LoadConfig(configFilePaths []string) (*LoadResult, error) { cfg.Config.SubgraphErrorPropagation.AllowedExtensionFields = unique.SliceElements(append(cfg.Config.SubgraphErrorPropagation.AllowedExtensionFields, "code", "stacktrace")) } + if err := ValidateMCPCodeMode(&cfg.Config.MCP.CodeMode, cfg.Config.MCP.Session.Stateless); err != nil { + return nil, err + } + return cfg, nil } diff --git 
a/router/pkg/config/config.schema.json b/router/pkg/config/config.schema.json index e90ae50407..f25531abee 100644 --- a/router/pkg/config/config.schema.json +++ b/router/pkg/config/config.schema.json @@ -2396,6 +2396,150 @@ } } }, + "code_mode": { + "type": "object", + "description": "Configuration for the Code Mode MCP server surface.", + "additionalProperties": false, + "properties": { + "enabled": { + "type": "boolean", + "default": false + }, + "server": { + "type": "object", + "additionalProperties": false, + "properties": { + "listen_addr": { + "type": "string", + "default": "localhost:5027", + "format": "hostname-port" + } + } + }, + "require_mutation_approval": { + "type": "boolean", + "default": true + }, + "execute_timeout": { + "type": "string", + "default": "120s", + "duration": { + "minimum": "0s" + } + }, + "max_result_bytes": { + "type": "integer", + "default": 32768 + }, + "sandbox": { + "type": "object", + "additionalProperties": false, + "properties": { + "timeout": { + "type": "string", + "default": "5s", + "duration": { + "minimum": "0s" + } + }, + "max_memory_mb": { + "type": "integer", + "default": 16 + }, + "max_input_size_bytes": { + "type": "integer", + "default": 65536 + }, + "max_output_size_bytes": { + "type": "integer", + "default": 1048576 + } + } + }, + "query_generation": { + "type": "object", + "additionalProperties": false, + "properties": { + "enabled": { + "type": "boolean", + "default": false + }, + "endpoint": { + "type": "string" + }, + "timeout": { + "type": "string", + "default": "10s", + "duration": { + "minimum": "0s" + } + }, + "auth": { + "type": "object", + "additionalProperties": false, + "properties": { + "type": { + "type": "string", + "default": "static" + }, + "static_token": { + "type": "string" + }, + "token_endpoint": { + "type": "string" + }, + "client_id": { + "type": "string" + }, + "client_secret": { + "type": "string" + } + } + } + } + }, + "named_ops": { + "type": "object", + "additionalProperties": false, 
+ "properties": { + "enabled": { + "type": "boolean", + "default": false + }, + "session_ttl": { + "type": "string", + "default": "30m", + "duration": { + "minimum": "0s" + } + }, + "max_sessions": { + "type": "integer", + "default": 1000 + }, + "max_bundle_bytes": { + "type": "integer", + "default": 262144 + }, + "storage": { + "type": "object", + "additionalProperties": false, + "properties": { + "provider_id": { + "type": "string", + "description": "ID of an entry in storage_providers.redis used to back named ops. When unset, the in-memory backend is used." + }, + "key_prefix": { + "type": "string", + "default": "cosmo_code_mode", + "description": "Key prefix applied to all named-ops keys written to the Redis storage provider." + } + } + } + } + } + } + }, "graph_name": { "type": "string", "default": "mygraph", diff --git a/router/pkg/config/fixtures/full.yaml b/router/pkg/config/fixtures/full.yaml index 42c6986d38..391d0838a0 100644 --- a/router/pkg/config/fixtures/full.yaml +++ b/router/pkg/config/fixtures/full.yaml @@ -64,11 +64,43 @@ mcp: omit_tool_name_prefix: false graph_name: cosmo router_url: https://cosmo-router.wundergraph.com + session: + stateless: false server: listen_addr: localhost:5025 base_url: 'http://localhost:5025' storage: provider_id: mcp + code_mode: + enabled: true + server: + listen_addr: localhost:6027 + require_mutation_approval: false + execute_timeout: 45s + max_result_bytes: 64000 + sandbox: + timeout: 7s + max_memory_mb: 32 + max_input_size_bytes: 131072 + max_output_size_bytes: 2097152 + query_generation: + enabled: true + endpoint: https://yoko.example.com + timeout: 15s + auth: + type: jwt + static_token: static-token + token_endpoint: https://auth.example.com/token + client_id: router-client + client_secret: router-secret + named_ops: + enabled: true + session_ttl: 45m + max_sessions: 2000 + max_bundle_bytes: 524288 + storage: + provider_id: my_redis + key_prefix: custom_code_mode watch_config: enabled: true diff --git 
a/router/pkg/config/testdata/config_defaults.json b/router/pkg/config/testdata/config_defaults.json index 8dc81bf6ed..7e892e1c27 100644 --- a/router/pkg/config/testdata/config_defaults.json +++ b/router/pkg/config/testdata/config_defaults.json @@ -159,6 +159,43 @@ "Session": { "Stateless": true }, + "CodeMode": { + "Enabled": false, + "Server": { + "ListenAddr": "localhost:5027" + }, + "RequireMutationApproval": true, + "ExecuteTimeout": 120000000000, + "MaxResultBytes": 32768, + "Sandbox": { + "Timeout": 5000000000, + "MaxMemoryMB": 16, + "MaxInputSizeBytes": 65536, + "MaxOutputSizeBytes": 1048576 + }, + "QueryGeneration": { + "Enabled": false, + "Endpoint": "", + "Timeout": 10000000000, + "Auth": { + "Type": "static", + "StaticToken": "", + "TokenEndpoint": "", + "ClientID": "", + "ClientSecret": "" + } + }, + "NamedOps": { + "Enabled": false, + "SessionTTL": 1800000000000, + "MaxSessions": 1000, + "MaxBundleBytes": 262144, + "Storage": { + "ProviderID": "", + "KeyPrefix": "cosmo_code_mode" + } + } + }, "GraphName": "mygraph", "ExcludeMutations": false, "EnableArbitraryOperations": false, diff --git a/router/pkg/config/testdata/config_full.json b/router/pkg/config/testdata/config_full.json index 33cc0c92e6..c2dceef5dc 100644 --- a/router/pkg/config/testdata/config_full.json +++ b/router/pkg/config/testdata/config_full.json @@ -202,7 +202,44 @@ "ProviderID": "mcp" }, "Session": { - "Stateless": true + "Stateless": false + }, + "CodeMode": { + "Enabled": true, + "Server": { + "ListenAddr": "localhost:6027" + }, + "RequireMutationApproval": false, + "ExecuteTimeout": 45000000000, + "MaxResultBytes": 64000, + "Sandbox": { + "Timeout": 7000000000, + "MaxMemoryMB": 32, + "MaxInputSizeBytes": 131072, + "MaxOutputSizeBytes": 2097152 + }, + "QueryGeneration": { + "Enabled": true, + "Endpoint": "https://yoko.example.com", + "Timeout": 15000000000, + "Auth": { + "Type": "jwt", + "StaticToken": "static-token", + "TokenEndpoint": "https://auth.example.com/token", + 
"ClientID": "router-client", + "ClientSecret": "router-secret" + } + }, + "NamedOps": { + "Enabled": true, + "SessionTTL": 2700000000000, + "MaxSessions": 2000, + "MaxBundleBytes": 524288, + "Storage": { + "ProviderID": "my_redis", + "KeyPrefix": "custom_code_mode" + } + } }, "GraphName": "cosmo", "ExcludeMutations": false, From 16489ad2fbf156c82d42e34b8104160c497c996a Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Wed, 6 May 2026 10:25:30 +0200 Subject: [PATCH 02/10] feat(router): sanitize non-serializable code mode return values Previously a single non-serializable leaf (BigInt, function, symbol, undefined, circular ref) caused the sandbox to drop the entire return value and surface a NotSerializable error. Multi-source aggregations that put raw federation errors or upstream payloads into a debug field lost the rest of an expensive run. The validator now mutates each bad leaf to the sentinel string "<>" and records {path, kind} in a new warnings array on the response envelope. The successful result still flows through. NotSerializable now only fires when JSON.stringify itself throws after sanitization. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- router/internal/codemode/harness/envelope.go | 16 ++-- router/internal/codemode/sandbox/execute.go | 4 +- router/internal/codemode/sandbox/sandbox.go | 9 +++ .../internal/codemode/sandbox/sandbox_test.go | 79 ++++++++++++++----- .../internal/codemode/sandbox/validation.go | 60 ++++++++------ router/internal/codemode/server/server.go | 2 +- 6 files changed, 120 insertions(+), 50 deletions(-) diff --git a/router/internal/codemode/harness/envelope.go b/router/internal/codemode/harness/envelope.go index f4bfca7171..7b09fec3ba 100644 --- a/router/internal/codemode/harness/envelope.go +++ b/router/internal/codemode/harness/envelope.go @@ -14,17 +14,20 @@ const defaultMaxResultBytes = 32 << 10 const previewBytes = 1 << 10 type ErrorEnvelope = sandbox.ErrorEnvelope +type SerializationWarning = sandbox.SerializationWarning // ResultEnvelope is the MCP-facing tool-result body for code_mode_run_js. // // Wire shape: // - result is always present (null if the agent threw). +// - warnings is omitted on the wire when empty. // - truncated is omitted on the wire when false (only signals a non-default state). // - error is omitted on the wire when nil (only present on the throw path). 
type ResultEnvelope struct { - Result json.RawMessage `json:"result"` - Truncated bool `json:"truncated,omitempty"` - Error *ErrorEnvelope `json:"error,omitempty"` + Result json.RawMessage `json:"result"` + Warnings []SerializationWarning `json:"warnings,omitempty"` + Truncated bool `json:"truncated,omitempty"` + Error *ErrorEnvelope `json:"error,omitempty"` } func BuildEnvelope(sandboxResult sandbox.ExecuteResult, maxResultBytes int) (ResultEnvelope, error) { @@ -34,12 +37,13 @@ func BuildEnvelope(sandboxResult sandbox.ExecuteResult, maxResultBytes int) (Res if !sandboxResult.OK { return ResultEnvelope{ Result: json.RawMessage("null"), + Warnings: sandboxResult.Warnings, Truncated: false, Error: cloneErrorEnvelope(sandboxResult.Error), }, nil } if len(sandboxResult.Result) <= maxResultBytes { - return ResultEnvelope{Result: sandboxResult.Result, Truncated: false, Error: nil}, nil + return ResultEnvelope{Result: sandboxResult.Result, Warnings: sandboxResult.Warnings, Truncated: false, Error: nil}, nil } truncated, ok, err := structurallyTruncate(sandboxResult.Result, maxResultBytes) @@ -47,13 +51,13 @@ func BuildEnvelope(sandboxResult sandbox.ExecuteResult, maxResultBytes int) (Res return ResultEnvelope{}, err } if ok { - return ResultEnvelope{Result: truncated, Truncated: true, Error: nil}, nil + return ResultEnvelope{Result: truncated, Warnings: sandboxResult.Warnings, Truncated: true, Error: nil}, nil } fallback, err := previewEnvelope(sandboxResult.Result) if err != nil { return ResultEnvelope{}, err } - return ResultEnvelope{Result: fallback, Truncated: true, Error: nil}, nil + return ResultEnvelope{Result: fallback, Warnings: sandboxResult.Warnings, Truncated: true, Error: nil}, nil } func cloneErrorEnvelope(err *ErrorEnvelope) *ErrorEnvelope { diff --git a/router/internal/codemode/sandbox/execute.go b/router/internal/codemode/sandbox/execute.go index a983730368..1a4a2f22f0 100644 --- a/router/internal/codemode/sandbox/execute.go +++ 
b/router/internal/codemode/sandbox/execute.go @@ -124,7 +124,7 @@ func (s *Sandbox) Execute(ctx context.Context, req ExecuteRequest) (execResult E } resultValue := value.GetPropertyStr("result") - result, validationErr, err := validateResult(qctx, resultValue, s.cfg.MaxOutputSizeBytes) + result, warnings, validationErr, err := validateResult(qctx, resultValue, s.cfg.MaxOutputSizeBytes) if err != nil { return runtimeErrorResult(err, execCtx, int(state.hostCalls.Load())), nil } @@ -132,6 +132,7 @@ func (s *Sandbox) Execute(ctx context.Context, req ExecuteRequest) (execResult E return ExecuteResult{ OK: false, Error: validationErr, + Warnings: warnings, OutputSize: envelopeSize(nil, validationErr), HostCalls: int(state.hostCalls.Load()), }, nil @@ -139,6 +140,7 @@ func (s *Sandbox) Execute(ctx context.Context, req ExecuteRequest) (execResult E return ExecuteResult{ OK: true, Result: result, + Warnings: warnings, OutputSize: envelopeSize(result, nil), HostCalls: int(state.hostCalls.Load()), }, nil diff --git a/router/internal/codemode/sandbox/sandbox.go b/router/internal/codemode/sandbox/sandbox.go index a4d42f91f3..7866eda79d 100644 --- a/router/internal/codemode/sandbox/sandbox.go +++ b/router/internal/codemode/sandbox/sandbox.go @@ -64,6 +64,7 @@ type ExecuteResult struct { OK bool Result json.RawMessage Error *ErrorEnvelope + Warnings []SerializationWarning Truncated bool OutputSize int HostCalls int @@ -76,6 +77,14 @@ type ErrorEnvelope struct { Cause *ErrorEnvelope `json:"cause,omitempty"` } +// SerializationWarning records a non-serializable value found in the script's +// return value. The bad value is replaced in the response with the sentinel +// string "<>" where KIND matches the reported Kind. 
+type SerializationWarning struct { + Path string `json:"path"` + Kind string `json:"kind"` +} + type ApprovalGate interface { Decide(ctx context.Context, req ApprovalRequest) (ApprovalDecision, error) } diff --git a/router/internal/codemode/sandbox/sandbox_test.go b/router/internal/codemode/sandbox/sandbox_test.go index db69e82679..48e8e8692a 100644 --- a/router/internal/codemode/sandbox/sandbox_test.go +++ b/router/internal/codemode/sandbox/sandbox_test.go @@ -363,34 +363,77 @@ func TestExecuteMemoryLimit(t *testing.T) { assert.Equal(t, "MemoryLimit", got.Error.Name) } -func TestExecuteNotSerializableResult(t *testing.T) { +func TestExecuteSanitizesNonSerializableField(t *testing.T) { s := newTestSandbox(t, "", lookup{}, nil) got := execute(t, s, ExecuteRequest{WrappedJS: `async () => ({ x: () => 1 })`}) - assert.Equal(t, false, got.OK) - assert.Equal(t, json.RawMessage(nil), got.Result) - require.NotNil(t, got.Error) - assert.Equal(t, ErrorEnvelope{ - Name: "NotSerializable", - Message: "return value contains non-JSON-serializable values at $.x", - Stack: "", - }, *got.Error) + assert.Equal(t, true, got.OK) + assert.Nil(t, got.Error) + assert.Equal(t, json.RawMessage(`{"x":"<>"}`), got.Result) + assert.Equal(t, []SerializationWarning{{Path: "$.x", Kind: "function"}}, got.Warnings) } -func TestExecuteNotSerializableProducesErrorEnvelope(t *testing.T) { +func TestExecuteSanitizesMixedNonSerializableValues(t *testing.T) { s := newTestSandbox(t, "", lookup{}, nil) got := execute(t, s, ExecuteRequest{WrappedJS: `async () => { return { x: () => 1, y: 5n, cycle: (() => { const o = {}; o.self = o; return o; })() }; }`}) - assert.Equal(t, false, got.OK) - assert.Equal(t, json.RawMessage(nil), got.Result) - require.NotNil(t, got.Error) - assert.Equal(t, ErrorEnvelope{ - Name: "NotSerializable", - Message: "return value contains non-JSON-serializable values at $.x, $.y, $.cycle.self", - Stack: "", - }, *got.Error) + assert.Equal(t, true, got.OK) + assert.Nil(t, got.Error) 
+ assert.Equal(t, json.RawMessage(`{"x":"<>","y":"<>","cycle":{"self":"<>"}}`), got.Result) + assert.Equal(t, []SerializationWarning{ + {Path: "$.x", Kind: "function"}, + {Path: "$.y", Kind: "bigint"}, + {Path: "$.cycle.self", Kind: "cycle"}, + }, got.Warnings) +} + +func TestExecuteSanitizesRootBigInt(t *testing.T) { + s := newTestSandbox(t, "", lookup{}, nil) + + got := execute(t, s, ExecuteRequest{WrappedJS: `async () => 5n`}) + + assert.Equal(t, true, got.OK) + assert.Nil(t, got.Error) + assert.Equal(t, json.RawMessage(`"<>"`), got.Result) + assert.Equal(t, []SerializationWarning{{Path: "$", Kind: "bigint"}}, got.Warnings) +} + +func TestExecuteSanitizesRootUndefined(t *testing.T) { + s := newTestSandbox(t, "", lookup{}, nil) + + got := execute(t, s, ExecuteRequest{WrappedJS: `async () => undefined`}) + + assert.Equal(t, true, got.OK) + assert.Nil(t, got.Error) + assert.Equal(t, json.RawMessage(`"<>"`), got.Result) + assert.Equal(t, []SerializationWarning{{Path: "$", Kind: "undefined"}}, got.Warnings) +} + +func TestExecuteSanitizesNonSerializableInArray(t *testing.T) { + s := newTestSandbox(t, "", lookup{}, nil) + + got := execute(t, s, ExecuteRequest{WrappedJS: `async () => [1, undefined, () => 2]`}) + + assert.Equal(t, true, got.OK) + assert.Nil(t, got.Error) + assert.Equal(t, json.RawMessage(`[1,"<>","<>"]`), got.Result) + assert.Equal(t, []SerializationWarning{ + {Path: "$[1]", Kind: "undefined"}, + {Path: "$[2]", Kind: "function"}, + }, got.Warnings) +} + +func TestExecuteCleanResultProducesNoWarnings(t *testing.T) { + s := newTestSandbox(t, "", lookup{}, nil) + + got := execute(t, s, ExecuteRequest{WrappedJS: `async () => ({ ok: true, n: 1, items: [1, 2, 3] })`}) + + assert.Equal(t, true, got.OK) + assert.Nil(t, got.Error) + assert.Equal(t, json.RawMessage(`{"ok":true,"n":1,"items":[1,2,3]}`), got.Result) + assert.Equal(t, []SerializationWarning(nil), got.Warnings) } func TestExecuteOutputTooLarge(t *testing.T) { diff --git 
a/router/internal/codemode/sandbox/validation.go b/router/internal/codemode/sandbox/validation.go index 8353bc8790..dae3ca74f3 100644 --- a/router/internal/codemode/sandbox/validation.go +++ b/router/internal/codemode/sandbox/validation.go @@ -3,7 +3,6 @@ package sandbox import ( "encoding/json" "fmt" - "strings" "github.com/fastschema/qjs" ) @@ -22,47 +21,54 @@ globalThis.__codemodeNormalizeError = (err, depth = 0) => { globalThis.__codemodeNormalizeErrorJSON = (err) => JSON.stringify(__codemodeNormalizeError(err)); globalThis.__codemodeValidateResult = (value) => { - const bad = []; + const warnings = []; const seen = new WeakSet(); const keyPath = (base, key) => { if (typeof key === "number") return base + "[" + key + "]"; return /^[A-Za-z_$][A-Za-z0-9_$]*$/.test(key) ? base + "." + key : base + "[" + JSON.stringify(key) + "]"; }; - const walk = (v, path) => { + const sentinel = (kind) => "<>"; + const walk = (v, path, parent, key) => { const t = typeof v; if (t === "bigint" || t === "function" || t === "symbol" || t === "undefined") { - bad.push(path); + parent[key] = sentinel(t); + warnings.push({ path, kind: t }); return; } if (v && t === "object") { if (seen.has(v)) { - bad.push(path); + parent[key] = sentinel("cycle"); + warnings.push({ path, kind: "cycle" }); return; } seen.add(v); if (Array.isArray(v)) { - for (let i = 0; i < v.length; i++) walk(v[i], keyPath(path, i)); + for (let i = 0; i < v.length; i++) walk(v[i], keyPath(path, i), v, i); return; } - for (const k of Object.keys(v)) walk(v[k], keyPath(path, k)); + for (const k of Object.keys(v)) walk(v[k], keyPath(path, k), v, k); } }; - walk(value, "$"); - if (bad.length) return JSON.stringify({ serializable: false, paths: bad }); + const root = { value }; + walk(root.value, "$", root, "value"); try { - const json = JSON.stringify(value); - if (json === undefined) return JSON.stringify({ serializable: false, paths: ["$"] }); - return JSON.stringify({ serializable: true, json }); + const json = 
JSON.stringify(root.value); + if (json === undefined) { + return JSON.stringify({ ok: false, warnings, error: "value serialized to undefined" }); + } + return JSON.stringify({ ok: true, json, warnings }); } catch (err) { - return JSON.stringify({ serializable: false, paths: ["$"] }); + const msg = err && err.message ? String(err.message) : String(err); + return JSON.stringify({ ok: false, warnings, error: msg }); } }; ` type validationOutcome struct { - Serializable bool `json:"serializable"` - JSON string `json:"json"` - Paths []string `json:"paths"` + OK bool `json:"ok"` + JSON string `json:"json"` + Warnings []SerializationWarning `json:"warnings"` + Error string `json:"error"` } func installValidationHelpers(ctx *qjs.Context) error { @@ -71,28 +77,34 @@ func installValidationHelpers(ctx *qjs.Context) error { return err } -func validateResult(ctx *qjs.Context, result *qjs.Value, maxOutputBytes int) (json.RawMessage, *ErrorEnvelope, error) { +func validateResult(ctx *qjs.Context, result *qjs.Value, maxOutputBytes int) (json.RawMessage, []SerializationWarning, *ErrorEnvelope, error) { global := ctx.Global() validator := global.GetPropertyStr("__codemodeValidateResult") encoded, err := ctx.Invoke(validator, global, result) if err != nil { - return nil, nil, err + return nil, nil, nil, err } var outcome validationOutcome if err := json.Unmarshal([]byte(encoded.String()), &outcome); err != nil { - return nil, nil, err + return nil, nil, nil, err + } + if len(outcome.Warnings) == 0 { + outcome.Warnings = nil } - if !outcome.Serializable { - message := "return value contains non-JSON-serializable values at " + strings.Join(outcome.Paths, ", ") - return nil, &ErrorEnvelope{Name: "NotSerializable", Message: message, Stack: ""}, nil + if !outcome.OK { + message := "JSON serialization failed after sanitization" + if outcome.Error != "" { + message = message + ": " + outcome.Error + } + return nil, outcome.Warnings, &ErrorEnvelope{Name: "NotSerializable", Message: message, 
Stack: ""}, nil } if len(outcome.JSON) > maxOutputBytes { - return nil, &ErrorEnvelope{ + return nil, outcome.Warnings, &ErrorEnvelope{ Name: "OutputTooLarge", Message: fmt.Sprintf("encoded result size %d bytes exceeds limit %d bytes", len(outcome.JSON), maxOutputBytes), Stack: "", }, nil } - return json.RawMessage(outcome.JSON), nil, nil + return json.RawMessage(outcome.JSON), outcome.Warnings, nil, nil } diff --git a/router/internal/codemode/server/server.go b/router/internal/codemode/server/server.go index 17fbffa32a..c669a7f5e6 100644 --- a/router/internal/codemode/server/server.go +++ b/router/internal/codemode/server/server.go @@ -40,7 +40,7 @@ const searchAPIDescription = "Plan ALL data shapes you need up front, then call const executeAPISourceDescription = "JavaScript source containing a single async arrow function. The host wraps it as `()()` and awaits the resulting Promise; the resolved JSON-serializable value is the tool result." -const executeAPIDescription = "Run JavaScript source as a single async arrow function in the Code Mode sandbox. Use `await tools.(vars)` for operations registered by code_mode_search_tools; the cumulative tools namespace is available at `yoko://persisted-ops.d.ts`.\n\nStyle: write compact source — single line if it fits, no // comments, no blank lines, short variable names. The JSON wrapping that encodes your source charges you for every newline and indent space.\n\nBatch everything into ONE code_mode_run_js call. ≥3 `tools.*` invocations per call is normal; over-fetch and decide in JS, don't round-trip. A failing inner call degrades the result, not the whole script — wrap with try/catch and surface the error in the return value.\n\nThe return value of your async arrow is the only output channel — `console` is not available. To surface intermediate state, include it in the returned object (e.g. `return { result, debug: { ... } }`). 
For resilient fan-out use `Promise.allSettled` — `Promise.all` rejects on first failure and discards partial results. Up to 256 `tools.*` invocations per call. Return values must be JSON-serializable; `BigInt`, functions, symbols, and circular refs throw `NotSerializable`.\n\nExample: `async()=>{const o=await tools.getOrders({customerId:\"c_1\"});if(o.errors?.length)throw new Error(o.errors[0].message);return o.data.orders;}`\n\nType declarations for reference (consumed via `yoko://persisted-ops.d.ts`):\n\n```ts\ntype GraphQLError = { message: string; path?: (string | number)[]; extensions?: Record };\ntype R = Promise<{ data: T | null; errors?: GraphQLError[] }>;\n\ndeclare const tools: {};\n\ndeclare function notNull(value: T | null | undefined, message?: string): T;\ndeclare function compact(value: T): T;\n```" +const executeAPIDescription = "Run JavaScript source as a single async arrow function in the Code Mode sandbox. Use `await tools.(vars)` for operations registered by code_mode_search_tools; the cumulative tools namespace is available at `yoko://persisted-ops.d.ts`.\n\nStyle: write compact source — single line if it fits, no // comments, no blank lines, short variable names. The JSON wrapping that encodes your source charges you for every newline and indent space.\n\nBatch everything into ONE code_mode_run_js call. ≥3 `tools.*` invocations per call is normal; over-fetch and decide in JS, don't round-trip. A failing inner call degrades the result, not the whole script — wrap with try/catch and surface the error in the return value.\n\nThe return value of your async arrow is the only output channel — `console` is not available. To surface intermediate state, include it in the returned object (e.g. `return { result, debug: { ... } }`). For resilient fan-out use `Promise.allSettled` — `Promise.all` rejects on first failure and discards partial results. Up to 256 `tools.*` invocations per call. 
Non-serializable leaves in the return value (`BigInt`, functions, symbols, `undefined`, circular refs) are replaced with the sentinel string `<>` and listed in the response's `warnings: [{path, kind}]` field; the rest of the value still comes through.\n\nExample: `async()=>{const o=await tools.getOrders({customerId:\"c_1\"});if(o.errors?.length)throw new Error(o.errors[0].message);return o.data.orders;}`\n\nType declarations for reference (consumed via `yoko://persisted-ops.d.ts`):\n\n```ts\ntype GraphQLError = { message: string; path?: (string | number)[]; extensions?: Record };\ntype R = Promise<{ data: T | null; errors?: GraphQLError[] }>;\n\ndeclare const tools: {};\n\ndeclare function notNull(value: T | null | undefined, message?: string): T;\ndeclare function compact(value: T): T;\n```" // Config configures the Code Mode MCP server. type Config struct { From a24711973bd5ca0342159afc2d1632ce744cc25e Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Wed, 6 May 2026 23:17:03 +0200 Subject: [PATCH 03/10] refactor(router): extract Code Mode MCP descriptions to embedded markdown Moves the long inline tool/resource description string constants out of server.go into a new internal/codemode/server/descriptions sub-package, where each description lives in its own .md file embedded via go:embed. Lets the prose be edited as readable markdown without touching Go source. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- .../server/descriptions/descriptions.go | 37 +++++++++ .../server/descriptions/execute_source.md | 3 + .../server/descriptions/execute_tool.md | 32 ++++++++ .../descriptions/persisted_ops_resource.md | 1 + .../server/descriptions/search_tool.md | 38 +++++++++ router/internal/codemode/server/server.go | 52 +++++++++--- .../internal/codemode/server/server_test.go | 81 ++++++++++++++++++- 7 files changed, 231 insertions(+), 13 deletions(-) create mode 100644 router/internal/codemode/server/descriptions/descriptions.go create mode 100644 router/internal/codemode/server/descriptions/execute_source.md create mode 100644 router/internal/codemode/server/descriptions/execute_tool.md create mode 100644 router/internal/codemode/server/descriptions/persisted_ops_resource.md create mode 100644 router/internal/codemode/server/descriptions/search_tool.md diff --git a/router/internal/codemode/server/descriptions/descriptions.go b/router/internal/codemode/server/descriptions/descriptions.go new file mode 100644 index 0000000000..3336fe7c77 --- /dev/null +++ b/router/internal/codemode/server/descriptions/descriptions.go @@ -0,0 +1,37 @@ +// Package descriptions holds the markdown text used as MCP server, tool, and +// resource descriptions for the Code Mode server. Each description lives in its +// own .md file and is embedded at compile time so prose can be edited without +// touching Go source. go:embed only supports vars (not consts), so each export +// is a package-level string treated as immutable. +package descriptions + +import ( + _ "embed" + "strings" +) + +//go:embed search_tool.md +var rawSearchTool string + +//go:embed execute_tool.md +var rawExecuteTool string + +//go:embed execute_source.md +var rawExecuteSource string + +//go:embed persisted_ops_resource.md +var rawPersistedOpsResource string + +// SearchTool is the description of the code_mode_search_tools MCP tool. 
+var SearchTool = strings.TrimRight(rawSearchTool, "\n") + +// ExecuteTool is the description of the code_mode_run_js MCP tool. +var ExecuteTool = strings.TrimRight(rawExecuteTool, "\n") + +// ExecuteSource is the description of the `source` input parameter of the +// code_mode_run_js MCP tool. +var ExecuteSource = strings.TrimRight(rawExecuteSource, "\n") + +// PersistedOpsResource is the description of the yoko://persisted-ops.d.ts MCP +// resource. +var PersistedOpsResource = strings.TrimRight(rawPersistedOpsResource, "\n") diff --git a/router/internal/codemode/server/descriptions/execute_source.md b/router/internal/codemode/server/descriptions/execute_source.md new file mode 100644 index 0000000000..178814932a --- /dev/null +++ b/router/internal/codemode/server/descriptions/execute_source.md @@ -0,0 +1,3 @@ +JavaScript source containing a single async arrow function. +The host wraps it as `(<source>)()` and awaits the resulting Promise; +the resolved JSON-serializable value is the tool result. \ No newline at end of file diff --git a/router/internal/codemode/server/descriptions/execute_tool.md b/router/internal/codemode/server/descriptions/execute_tool.md new file mode 100644 index 0000000000..28646da1e1 --- /dev/null +++ b/router/internal/codemode/server/descriptions/execute_tool.md @@ -0,0 +1,32 @@ +Run JavaScript source as a single async arrow function in the Code Mode sandbox. +Use `await tools.<name>(vars)` for operations registered by code_mode_search_tools; +the cumulative tools namespace is available at `yoko://persisted-ops.d.ts`. + +Style: write compact source — single line if it fits, no // comments, no blank lines, short variable names. +The JSON wrapping that encodes your source charges you for every newline and indent space. + +Batch everything into ONE code_mode_run_js call. +≥3 `tools.*` invocations per call is normal; +over-fetch and decide in JS, don't round-trip. 
+A failing inner call degrades the result, not the whole script — wrap with try/catch and surface the error in the return value. + +The return value of your async arrow is the only output channel — `console` is not available. +To surface intermediate state, include it in the returned object (e.g. `return { result, debug: { ... } }`). +For resilient fan-out use `Promise.allSettled` — `Promise.all` rejects on first failure and discards partial results. +Up to 256 `tools.*` invocations per call. +Non-serializable leaves in the return value (`BigInt`, functions, symbols, `undefined`, circular refs) are replaced with the sentinel string `<>` and listed in the response's `warnings: [{path, kind}]` field; +the rest of the value still comes through. + +Example: `async()=>{const o=await tools.getOrders({customerId:"c_1"});if(o.errors?.length)throw new Error(o.errors[0].message);return o.data.orders;}` + +Type declarations for reference (consumed via `yoko://persisted-ops.d.ts`): + +```ts +type GraphQLError = { message: string; path?: (string | number)[]; extensions?: Record }; +type R = Promise<{ data: T | null; errors?: GraphQLError[] }>; + +declare const tools: {}; + +declare function notNull(value: T | null | undefined, message?: string): T; +declare function compact(value: T): T; +``` \ No newline at end of file diff --git a/router/internal/codemode/server/descriptions/persisted_ops_resource.md b/router/internal/codemode/server/descriptions/persisted_ops_resource.md new file mode 100644 index 0000000000..2e119392d6 --- /dev/null +++ b/router/internal/codemode/server/descriptions/persisted_ops_resource.md @@ -0,0 +1 @@ +Cumulative TypeScript definitions for the current Code Mode MCP session's named operations. 
\ No newline at end of file diff --git a/router/internal/codemode/server/descriptions/search_tool.md b/router/internal/codemode/server/descriptions/search_tool.md new file mode 100644 index 0000000000..64581a38e1 --- /dev/null +++ b/router/internal/codemode/server/descriptions/search_tool.md @@ -0,0 +1,38 @@ +Plan ALL data shapes you need up front, +then call ONCE with every prompt in a single batch. +Each extra search is a round-trip you pay for. + +DEFAULT TO ONE PROMPT. +If the entities are related in any way — same domain, joinable, fetched together to answer one question, +traversed via the same parent, or the user mentioned them in the same breath — combine them into a SINGLE prompt that describes the complete joined shape. +Multiple prompts should be the exception, not the default. + +Write each prompt as the COMPLETE final shape of data you want, including joins and correlation IDs. + +Write prompts in a graph-like shape with relationships and nesting, not as separate flat queries. + +BE PRECISE about what you need. +Vague prompts produce vague operations and force re-searches. +Always state: +- The exact fields you need on each entity ("id, forename, surname" — not "name info"). +- Any required filters/arguments but never specific values ("employee by id" - not "employee 123", "employee filtered by department name" - not "employee in department 'Engineering'"). +- Concrete entity and relationship names from the domain when you know them; otherwise describe the relationship explicitly ("the team an employee belongs to"). + +When to use multiple prompts (rare): genuinely unrelated operations on disjoint domains, different argument shapes that can't share a parent, or queries vs mutations. +Never slice one joinable shape into fragments. +When in doubt, combine. + +Do NOT issue prompts for derived/computed values: averages, medians, counts, filters, exclusions ("without X"), sorting, top-N. +Fetch the raw rows once and compute in code_mode_run_js. 
+Yoko exposes data; arithmetic and reshaping happen in your JS. + +Anti-pattern: search → inspect result → notice a field or ID is missing → search again. +One well-formed prompt beats three round-trips. + +The response appends newly registered TypeScript declarations for use as `await tools.<name>(vars)` inside code_mode_run_js; +the cumulative bundle is available at `yoko://persisted-ops.d.ts`. + +Good example: "employee filtered by id with fields id, forename, surname, role, startDate; their team with fields id, name and the team's department with fields id, name; the projects the employee is assigned to with fields id, title, status, dueDate and each project's owner (employee) with fields id, forename, surname" + +Bad examples: ["list of employees with name info", "team for employee 123", "projects in department 'Engineering'", "top 5 employees by project count", "average project duration per team"] +— five prompts instead of one joined shape, vague fields ("name info"), hardcoded filter values ("123", "'Engineering'"), and derived/computed results (top-N, average) that belong in code_mode_run_js, not in a search prompt. 
\ No newline at end of file diff --git a/router/internal/codemode/server/server.go b/router/internal/codemode/server/server.go index c669a7f5e6..c4c0eaee0f 100644 --- a/router/internal/codemode/server/server.go +++ b/router/internal/codemode/server/server.go @@ -14,6 +14,7 @@ import ( "github.com/wundergraph/cosmo/router/internal/codemode/harness" "github.com/wundergraph/cosmo/router/internal/codemode/observability" "github.com/wundergraph/cosmo/router/internal/codemode/sandbox" + "github.com/wundergraph/cosmo/router/internal/codemode/server/descriptions" "github.com/wundergraph/cosmo/router/internal/codemode/storage" "github.com/wundergraph/cosmo/router/internal/codemode/tsgen" "github.com/wundergraph/cosmo/router/internal/codemode/yoko" @@ -36,12 +37,6 @@ const ( namedOpsDisabledMessage = "named operations are disabled" ) -const searchAPIDescription = "Plan ALL data shapes you need up front, then call ONCE with every prompt in a single batch. Each extra search is a round-trip you pay for.\n\nDEFAULT TO ONE PROMPT. If the entities are related in any way — same domain, joinable, fetched together to answer one question, traversed via the same parent, or the user mentioned them in the same breath — combine them into a SINGLE prompt that describes the complete joined shape. Multiple prompts should be the exception, not the default.\n\nWrite each prompt as the COMPLETE final shape of data you want, including joins and correlation IDs. Yoko writes GraphQL across federated subgraphs, so a single prompt like \"employees with id, first name, last name, and their pets (name, type)\" returns one joined operation — never split this into \"list employees\" + \"list pets with owner\" that you'd then have to correlate in JS. If you DO split, every prompt MUST include the join keys (IDs / foreign keys) needed to correlate the results — otherwise the operations come back un-joinable and you'll have to search again.\n\nBE PRECISE about what you need. 
Vague prompts produce vague operations and force re-searches. Always state:\n- The exact fields you need on each entity (\"id, forename, surname\" — not \"name info\").\n- The relationships to traverse and how deep (\"employees with their pets and each pet's owner's department\").\n- Any required filters/arguments and the values or variable names (\"by id=42\", \"where status=ACTIVE\", \"limit 50\").\n- The shape of nested/related entities, field by field — do not say \"with all their data\".\n- Concrete entity and relationship names from the domain when you know them; otherwise describe the relationship explicitly (\"the team an employee belongs to\").\nA precise prompt: \"employee by id (variable: $id) returning id, forename, surname, role, and pets { name, type, age }\". A vague prompt: \"get employee details with related stuff\" — this will come back missing fields you need.\n\nWhen to use multiple prompts (rare): genuinely unrelated operations on disjoint domains, different argument shapes that can't share a parent, or queries vs mutations. Never slice one joinable shape into fragments. When in doubt, combine.\n\nDo NOT issue prompts for derived/computed values: averages, medians, counts, filters, exclusions (\"without X\"), sorting, top-N. Fetch the raw rows once and compute in code_mode_run_js. Yoko exposes data; arithmetic and reshaping happen in your JS.\n\nAnti-pattern: search → inspect result → notice a field or ID is missing → search again. One well-formed prompt beats three round-trips.\n\nThe response appends newly registered TypeScript declarations for use as `await tools.(vars)` inside code_mode_run_js; the cumulative bundle is available at `yoko://persisted-ops.d.ts`." - -const executeAPISourceDescription = "JavaScript source containing a single async arrow function. The host wraps it as `()()` and awaits the resulting Promise; the resolved JSON-serializable value is the tool result." 
- -const executeAPIDescription = "Run JavaScript source as a single async arrow function in the Code Mode sandbox. Use `await tools.(vars)` for operations registered by code_mode_search_tools; the cumulative tools namespace is available at `yoko://persisted-ops.d.ts`.\n\nStyle: write compact source — single line if it fits, no // comments, no blank lines, short variable names. The JSON wrapping that encodes your source charges you for every newline and indent space.\n\nBatch everything into ONE code_mode_run_js call. ≥3 `tools.*` invocations per call is normal; over-fetch and decide in JS, don't round-trip. A failing inner call degrades the result, not the whole script — wrap with try/catch and surface the error in the return value.\n\nThe return value of your async arrow is the only output channel — `console` is not available. To surface intermediate state, include it in the returned object (e.g. `return { result, debug: { ... } }`). For resilient fan-out use `Promise.allSettled` — `Promise.all` rejects on first failure and discards partial results. Up to 256 `tools.*` invocations per call. Non-serializable leaves in the return value (`BigInt`, functions, symbols, `undefined`, circular refs) are replaced with the sentinel string `<>` and listed in the response's `warnings: [{path, kind}]` field; the rest of the value still comes through.\n\nExample: `async()=>{const o=await tools.getOrders({customerId:\"c_1\"});if(o.errors?.length)throw new Error(o.errors[0].message);return o.data.orders;}`\n\nType declarations for reference (consumed via `yoko://persisted-ops.d.ts`):\n\n```ts\ntype GraphQLError = { message: string; path?: (string | number)[]; extensions?: Record };\ntype R = Promise<{ data: T | null; errors?: GraphQLError[] }>;\n\ndeclare const tools: {};\n\ndeclare function notNull(value: T | null | undefined, message?: string): T;\ndeclare function compact(value: T): T;\n```" - // Config configures the Code Mode MCP server. 
type Config struct { ListenAddr string @@ -254,6 +249,43 @@ func (s *Server) Reload(schema *ast.Document, sdl string) error { } if s.yokoClient != nil { s.yokoClient.SetSchema(sdl) + // Eagerly index the new SDL in the background so the first user-facing + // code_mode_search_tools call doesn't pay the IndexSchema round-trip + // latency. Failures are logged and ignored — the lazy path inside + // Search will retry on the next call. + // + // recover guard: an unrecovered panic here would bring the whole + // router down because the goroutine runs outside any caller frame. + // The warm-up is best-effort, so a panic must never escape. + if sdl != "" { + yokoClient := s.yokoClient + logger := s.logger + sdlBytes := len(sdl) + go func() { + start := time.Now() + defer func() { + if r := recover(); r != nil { + logger.Error("code mode eager schema index panicked", + zap.Any("panic", r), + zap.Duration("duration", time.Since(start)), + ) + } + }() + logger.Info("code mode eager schema index started", + zap.Int("sdl_bytes", sdlBytes), + ) + if err := yokoClient.EnsureIndexed(context.Background()); err != nil { + logger.Warn("code mode eager schema index failed", + zap.Error(err), + zap.Duration("duration", time.Since(start)), + ) + return + } + logger.Info("code mode eager schema index completed", + zap.Duration("duration", time.Since(start)), + ) + }() + } } if s.sessionStateless && s.namedOpsEnabled { s.warnStatelessNamedOpsOnce() @@ -265,13 +297,13 @@ func (s *Server) Reload(schema *ast.Document, sdl string) error { func (s *Server) registerTools() { s.mcpServer.AddTool(&mcp.Tool{ Name: "code_mode_search_tools", - Description: searchAPIDescription, + Description: descriptions.SearchTool, InputSchema: searchAPIInputSchema(), }, s.handleSearch) s.mcpServer.AddTool(&mcp.Tool{ Name: "code_mode_run_js", - Description: executeAPIDescription, + Description: descriptions.ExecuteTool, InputSchema: executeAPIInputSchema(), }, s.handleExecute) } @@ -281,7 +313,7 @@ func (s 
*Server) registerPersistedOpsResource() { URI: persistedOpsURI, Name: "persisted-ops.d.ts", Title: "Persisted operations TypeScript definitions", - Description: "Cumulative TypeScript definitions for the current Code Mode MCP session's named operations.", + Description: descriptions.PersistedOpsResource, MIMEType: "text/plain", }, s.handlePersistedOpsResource) } @@ -430,7 +462,7 @@ func executeAPIInputSchema() map[string]any { "source": map[string]any{ "type": "string", "minLength": 1, - "description": executeAPISourceDescription, + "description": descriptions.ExecuteSource, }, }, } diff --git a/router/internal/codemode/server/server_test.go b/router/internal/codemode/server/server_test.go index 65f153dfa7..ed7151c5c6 100644 --- a/router/internal/codemode/server/server_test.go +++ b/router/internal/codemode/server/server_test.go @@ -14,6 +14,7 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router/internal/codemode/server/descriptions" "github.com/wundergraph/cosmo/router/internal/codemode/storage" "github.com/wundergraph/cosmo/router/internal/codemode/yoko" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" @@ -71,7 +72,7 @@ func TestListToolsReturnsCodeModeTools(t *testing.T) { assert.Equal(t, mustJSON(t, []*mcp.Tool{ { Name: "code_mode_run_js", - Description: executeAPIDescription, + Description: descriptions.ExecuteTool, InputSchema: map[string]any{ "type": "object", "required": []any{"source"}, @@ -79,14 +80,14 @@ func TestListToolsReturnsCodeModeTools(t *testing.T) { "source": map[string]any{ "type": "string", "minLength": float64(1), - "description": executeAPISourceDescription, + "description": descriptions.ExecuteSource, }, }, }, }, { Name: "code_mode_search_tools", - Description: searchAPIDescription, + Description: descriptions.SearchTool, InputSchema: map[string]any{ "type": "object", "required": []any{"prompts"}, @@ -272,6 +273,80 @@ 
func TestReloadForwardsSchemaAndSDL(t *testing.T) { assert.Equal(t, "schema { query: Query }", client.Schema()) } +func TestReloadEagerlyIndexesViaBackgroundGoroutine(t *testing.T) { + core, recorded := observer.New(zap.InfoLevel) + searcher := newFakeYoko() + srv, err := New(Config{ + ListenAddr: "127.0.0.1:0", + CodeModeEnabled: true, + NamedOpsEnabled: false, + SessionStateless: false, + Storage: newRecordingStorage(), + YokoClient: searcher, + BundleRenderer: storage.RendererFunc(func([]storage.SessionOp) (string, error) { return "", nil }), + Logger: zap.New(core), + }) + require.NoError(t, err) + + require.NoError(t, srv.Reload(&ast.Document{}, "schema { query: Query }")) + + require.Eventually(t, func() bool { + return searcher.ensureIndexedCallCount() == 1 + }, 2*time.Second, 5*time.Millisecond, "eager index should fire once after Reload") + + require.Eventually(t, func() bool { + return recorded.FilterMessage("code mode eager schema index started").Len() == 1 && + recorded.FilterMessage("code mode eager schema index completed").Len() == 1 + }, 2*time.Second, 5*time.Millisecond, "expected start+completed info logs") +} + +func TestReloadEagerIndexLogsWarnOnFailure(t *testing.T) { + core, recorded := observer.New(zap.InfoLevel) + searcher := newFakeYoko() + searcher.ensureIndexedErr = errors.New("yoko unreachable") + srv, err := New(Config{ + ListenAddr: "127.0.0.1:0", + CodeModeEnabled: true, + NamedOpsEnabled: false, + SessionStateless: false, + Storage: newRecordingStorage(), + YokoClient: searcher, + BundleRenderer: storage.RendererFunc(func([]storage.SessionOp) (string, error) { return "", nil }), + Logger: zap.New(core), + }) + require.NoError(t, err) + + require.NoError(t, srv.Reload(&ast.Document{}, "schema { query: Query }")) + + require.Eventually(t, func() bool { + return recorded.FilterMessage("code mode eager schema index started").Len() == 1 && + recorded.FilterMessage("code mode eager schema index failed").Len() == 1 && + 
recorded.FilterMessage("code mode eager schema index completed").Len() == 0 + }, 2*time.Second, 5*time.Millisecond, "expected start+failed logs without completed log") +} + +func TestReloadEagerIndexSkippedWhenSDLEmpty(t *testing.T) { + searcher := newFakeYoko() + srv, err := New(Config{ + ListenAddr: "127.0.0.1:0", + CodeModeEnabled: true, + NamedOpsEnabled: false, + SessionStateless: false, + Storage: newRecordingStorage(), + YokoClient: searcher, + BundleRenderer: storage.RendererFunc(func([]storage.SessionOp) (string, error) { return "", nil }), + Logger: zap.NewNop(), + }) + require.NoError(t, err) + + require.NoError(t, srv.Reload(&ast.Document{}, "")) + + // Give the goroutine that EnsureIndexed *would* have launched a chance to + // run; assert it never did. + time.Sleep(50 * time.Millisecond) + assert.Equal(t, 0, searcher.ensureIndexedCallCount()) +} + func TestReloadDisabledIsNoOp(t *testing.T) { store := newRecordingStorage() client := yoko.New(nil, "http://127.0.0.1", zap.NewNop()) From 1448ef0f6fc2667aba8bbded43e9bcc954622367 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Thu, 7 May 2026 10:57:44 +0200 Subject: [PATCH 04/10] feat(code-mode): auto-reconnect upstream from MCP stdio proxy The proxy now retries the initial dial with exponential backoff and supervises the live session via session.Wait(); when the upstream disconnects it reconnects transparently so downstream clients (Claude Desktop, etc.) keep working across router restarts. KeepAlive pings make dead connections surface quickly instead of waiting for the next call. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- demo/code-mode/mcp-stdio-proxy/main.go | 208 +++++++++++++++++++- demo/code-mode/mcp-stdio-proxy/main_test.go | 88 +++++++++ 2 files changed, 286 insertions(+), 10 deletions(-) diff --git a/demo/code-mode/mcp-stdio-proxy/main.go b/demo/code-mode/mcp-stdio-proxy/main.go index 03ab9df8aa..6f172d100f 100644 --- a/demo/code-mode/mcp-stdio-proxy/main.go +++ b/demo/code-mode/mcp-stdio-proxy/main.go @@ -9,8 +9,10 @@ import ( "net/http" "os" "os/signal" + "sync" "sync/atomic" "syscall" + "time" "github.com/modelcontextprotocol/go-sdk/mcp" ) @@ -21,12 +23,22 @@ const ( defaultUpstreamURL = "http://127.0.0.1:5027/mcp" proxyName = "yoko-stdio-proxy" proxyVersion = "0.1.0" + + initialReconnectBackoff = 500 * time.Millisecond + maxReconnectBackoff = 30 * time.Second + upstreamKeepAlive = 30 * time.Second ) type proxyOptions struct { upstreamURL string transport mcp.Transport httpClient *http.Client + // keepAlive overrides the upstream client KeepAlive interval. Zero uses the + // default. Tests use a short interval so disconnects are detected quickly. + keepAlive time.Duration + // initialBackoff overrides the initial reconnect backoff. Zero uses the + // default. Tests use a short value to keep reconnect latency low. 
+ initialBackoff time.Duration } func main() { @@ -66,11 +78,20 @@ func runProxy(ctx context.Context, opts proxyOptions) error { if opts.transport == nil { opts.transport = &mcp.StdioTransport{} } + keepAlive := opts.keepAlive + if keepAlive == 0 { + keepAlive = upstreamKeepAlive + } + initialBackoff := opts.initialBackoff + if initialBackoff == 0 { + initialBackoff = initialReconnectBackoff + } var localSession atomic.Pointer[mcp.ServerSession] upstreamClient := mcp.NewClient( &mcp.Implementation{Name: proxyName, Version: proxyVersion}, &mcp.ClientOptions{ + KeepAlive: keepAlive, ElicitationHandler: func(ctx context.Context, req *mcp.ElicitRequest) (*mcp.ElicitResult, error) { ss := localSession.Load() if ss == nil { @@ -81,28 +102,52 @@ func runProxy(ctx context.Context, opts proxyOptions) error { }, ) - upstreamSession, err := upstreamClient.Connect(ctx, &mcp.StreamableClientTransport{ - Endpoint: opts.upstreamURL, - HTTPClient: opts.httpClient, - }, nil) + upstream := &upstreamConn{ + client: upstreamClient, + upstreamURL: opts.upstreamURL, + httpClient: opts.httpClient, + initialBackoff: initialBackoff, + ready: make(chan struct{}), + } + + initialSession, err := upstream.connectWithRetry(ctx, "upstream connect") if err != nil { + if errors.Is(err, context.Canceled) { + return err + } return fmt.Errorf("connect upstream %q failed: %w; is the demo running? 
try `make code-mode-demo`", opts.upstreamURL, err) } + upstream.setSession(initialSession) + defer func() { - if err := upstreamSession.Close(); err != nil { - log.Printf("mcp-stdio-proxy: upstream close failed: %v", err) + if s := upstream.currentSession(); s != nil { + if err := s.Close(); err != nil { + log.Printf("mcp-stdio-proxy: upstream close failed: %v", err) + } } }() - toolsResp, err := upstreamSession.ListTools(ctx, &mcp.ListToolsParams{}) + toolsResp, err := initialSession.ListTools(ctx, &mcp.ListToolsParams{}) if err != nil { return fmt.Errorf("list upstream tools: %w", err) } - resourcesResp, err := upstreamSession.ListResources(ctx, &mcp.ListResourcesParams{}) + resourcesResp, err := initialSession.ListResources(ctx, &mcp.ListResourcesParams{}) if err != nil { return fmt.Errorf("list upstream resources: %w", err) } + supervisorCtx, cancelSupervisor := context.WithCancel(ctx) + defer cancelSupervisor() + supervisorDone := make(chan struct{}) + go func() { + defer close(supervisorDone) + upstream.supervise(supervisorCtx, initialSession) + }() + defer func() { + cancelSupervisor() + <-supervisorDone + }() + localServer := mcp.NewServer( &mcp.Implementation{Name: "yoko (via stdio-proxy)", Version: proxyVersion}, &mcp.ServerOptions{ @@ -130,7 +175,13 @@ func runProxy(ctx context.Context, opts proxyOptions) error { tool.InputSchema = map[string]any{"type": "object"} } localServer.AddTool(&tool, func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { - result, err := upstreamSession.CallTool(ctx, &mcp.CallToolParams{ + session, err := upstream.awaitSession(ctx) + if err != nil { + var errResult mcp.CallToolResult + errResult.SetError(fmt.Errorf("upstream tool %q unavailable: %w", req.Params.Name, err)) + return &errResult, nil + } + result, err := session.CallTool(ctx, &mcp.CallToolParams{ Meta: req.Params.Meta, Name: req.Params.Name, Arguments: req.Params.Arguments, @@ -147,7 +198,11 @@ func runProxy(ctx context.Context, opts 
proxyOptions) error { for _, upstreamResource := range resourcesResp.Resources { resource := *upstreamResource localServer.AddResource(&resource, func(ctx context.Context, req *mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) { - result, err := upstreamSession.ReadResource(ctx, req.Params) + session, err := upstream.awaitSession(ctx) + if err != nil { + return nil, fmt.Errorf("upstream resource %q unavailable: %w", req.Params.URI, err) + } + result, err := session.ReadResource(ctx, req.Params) if err != nil { return nil, fmt.Errorf("upstream resource %q failed: %w", req.Params.URI, err) } @@ -160,3 +215,136 @@ func runProxy(ctx context.Context, opts proxyOptions) error { } return nil } + +// upstreamConn keeps a live MCP client session to the upstream router, dialing +// initially with backoff and reconnecting transparently when the session drops. +type upstreamConn struct { + client *mcp.Client + upstreamURL string + httpClient *http.Client + initialBackoff time.Duration + + mu sync.Mutex + session *mcp.ClientSession + ready chan struct{} +} + +func (u *upstreamConn) dial(ctx context.Context) (*mcp.ClientSession, error) { + return u.client.Connect(ctx, &mcp.StreamableClientTransport{ + Endpoint: u.upstreamURL, + HTTPClient: u.httpClient, + }, nil) +} + +// connectWithRetry dials the upstream, retrying with exponential backoff until +// the context is cancelled. 
+func (u *upstreamConn) connectWithRetry(ctx context.Context, label string) (*mcp.ClientSession, error) { + backoff := u.initialBackoff + if backoff == 0 { + backoff = initialReconnectBackoff + } + for attempt := 1; ; attempt++ { + s, err := u.dial(ctx) + if err == nil { + if attempt > 1 { + log.Printf("mcp-stdio-proxy: %s succeeded on attempt %d", label, attempt) + } + return s, nil + } + if ctxErr := ctx.Err(); ctxErr != nil { + return nil, ctxErr + } + log.Printf("mcp-stdio-proxy: %s attempt %d failed: %v; retrying in %s", label, attempt, err, backoff) + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(backoff): + } + if backoff < maxReconnectBackoff { + backoff *= 2 + if backoff > maxReconnectBackoff { + backoff = maxReconnectBackoff + } + } + } +} + +// supervise watches the active upstream session and reconnects when it drops. +// Returns when ctx is cancelled. +func (u *upstreamConn) supervise(ctx context.Context, initial *mcp.ClientSession) { + cur := initial + for { + waitDone := make(chan struct{}) + go func(s *mcp.ClientSession) { + _ = s.Wait() + close(waitDone) + }(cur) + + select { + case <-ctx.Done(): + return + case <-waitDone: + } + if ctx.Err() != nil { + return + } + + log.Printf("mcp-stdio-proxy: upstream session closed; reconnecting...") + u.markUnready() + + next, err := u.connectWithRetry(ctx, "upstream reconnect") + if err != nil { + return + } + u.setSession(next) + log.Printf("mcp-stdio-proxy: upstream reconnected") + cur = next + } +} + +// awaitSession returns the current upstream session, blocking until one is +// available or ctx is cancelled. +func (u *upstreamConn) awaitSession(ctx context.Context) (*mcp.ClientSession, error) { + for { + u.mu.Lock() + if u.session != nil { + s := u.session + u.mu.Unlock() + return s, nil + } + ready := u.ready + u.mu.Unlock() + select { + case <-ready: + case <-ctx.Done(): + return nil, ctx.Err() + } + } +} + +// currentSession returns the current session without blocking. 
Used at shutdown +// to close whatever session is live. +func (u *upstreamConn) currentSession() *mcp.ClientSession { + u.mu.Lock() + defer u.mu.Unlock() + return u.session +} + +func (u *upstreamConn) setSession(s *mcp.ClientSession) { + u.mu.Lock() + defer u.mu.Unlock() + u.session = s + if u.ready != nil { + close(u.ready) + u.ready = nil + } +} + +func (u *upstreamConn) markUnready() { + u.mu.Lock() + defer u.mu.Unlock() + u.session = nil + if u.ready == nil { + u.ready = make(chan struct{}) + } +} diff --git a/demo/code-mode/mcp-stdio-proxy/main_test.go b/demo/code-mode/mcp-stdio-proxy/main_test.go index 0086a70b2b..a3a6e2a9d8 100644 --- a/demo/code-mode/mcp-stdio-proxy/main_test.go +++ b/demo/code-mode/mcp-stdio-proxy/main_test.go @@ -6,6 +6,7 @@ import ( "net" "net/http" "net/http/httptest" + "sync/atomic" "testing" "time" @@ -149,6 +150,93 @@ func TestProxyMirrorsUpstreamSurfaceAndForwardsElicitation(t *testing.T) { } } +func TestProxyReconnectsAfterUpstreamDisconnect(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + server := mcp.NewServer(&mcp.Implementation{Name: "test-upstream", Version: "0.1.0"}, nil) + server.AddTool(&mcp.Tool{ + Name: "echo", + Description: "Echo the input.", + InputSchema: map[string]any{ + "type": "object", + "additionalProperties": true, + }, + }, func(_ context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: string(req.Params.Arguments)}}, + StructuredContent: req.Params.Arguments, + }, nil + }) + mcpHandler := mcp.NewStreamableHTTPHandler(func(*http.Request) *mcp.Server { + return server + }, nil) + + // Switchable handler: when "off", every request returns 503 so both the + // keepalive ping on the live session and any reconnect dials fail. 
+ var upstreamUp atomic.Bool + upstreamUp.Store(true) + httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !upstreamUp.Load() { + http.Error(w, "upstream off", http.StatusServiceUnavailable) + return + } + mcpHandler.ServeHTTP(w, r) + })) + defer httpServer.Close() + + serverTransport, clientTransport := mcp.NewInMemoryTransports() + errCh := make(chan error, 1) + go func() { + errCh <- runProxy(ctx, proxyOptions{ + upstreamURL: httpServer.URL, + transport: serverTransport, + httpClient: httpServer.Client(), + keepAlive: 100 * time.Millisecond, + initialBackoff: 50 * time.Millisecond, + }) + }() + + client := mcp.NewClient(&mcp.Implementation{Name: "test-client", Version: "0.1.0"}, nil) + session, err := client.Connect(ctx, clientTransport, nil) + require.NoError(t, err) + defer func() { + require.NoError(t, session.Close()) + err := <-errCh + if !errors.Is(err, context.Canceled) { + require.NoError(t, err) + } + }() + + resp, err := session.CallTool(ctx, &mcp.CallToolParams{ + Name: "echo", + Arguments: map[string]any{"x": 1}, + }) + require.NoError(t, err) + assert.Equal(t, &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: `{"x":1}`}}, + StructuredContent: map[string]any{"x": float64(1)}, + }, resp) + + upstreamUp.Store(false) + time.Sleep(400 * time.Millisecond) + upstreamUp.Store(true) + + require.Eventually(t, func() bool { + resp, err := session.CallTool(ctx, &mcp.CallToolParams{ + Name: "echo", + Arguments: map[string]any{"x": 2}, + }) + if err != nil { + return false + } + return assert.ObjectsAreEqual(&mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: `{"x":2}`}}, + StructuredContent: map[string]any{"x": float64(2)}, + }, resp) + }, 10*time.Second, 100*time.Millisecond, "expected proxy to reconnect and serve calls") +} + func newTestUpstream(t *testing.T) *httptest.Server { t.Helper() From 939da747e836bc5ed6466fdd695b054982972b80 Mon Sep 17 00:00:00 2001 From: Jens 
Neuse Date: Thu, 7 May 2026 12:31:07 +0200 Subject: [PATCH 05/10] fix(router): unblock long code_mode_run_js calls and drop dead executeState.wg WriteTimeout was 30s while executeTimeout is 120s, so net/http would cut off legitimately long code_mode_run_js responses mid-flight. Lift the write deadline above the configured execute timeout and switch the listener to ReadHeaderTimeout so a body upload doesn't share the 30s budget either. The unused sync.WaitGroup on executeState (and its matching state.wg.Wait() in Execute) was a no-op leftover; remove both. Co-Authored-By: Claude Opus 4.7 (1M context) --- router/internal/codemode/sandbox/execute.go | 1 - router/internal/codemode/sandbox/host.go | 1 - router/internal/codemode/server/server.go | 32 ++++++++++++--------- 3 files changed, 18 insertions(+), 16 deletions(-) diff --git a/router/internal/codemode/sandbox/execute.go b/router/internal/codemode/sandbox/execute.go index 1a4a2f22f0..37747a83bc 100644 --- a/router/internal/codemode/sandbox/execute.go +++ b/router/internal/codemode/sandbox/execute.go @@ -61,7 +61,6 @@ func (s *Sandbox) Execute(ctx context.Context, req ExecuteRequest) (execResult E qctx := rt.Context() state := &executeState{req: req} defer func() { - state.wg.Wait() // qjs panics on Close when the runtime context has already been cancelled. // Treat the runtime as best-effort cleanup; a leaked WASM instance is bounded // by GC and the per-call freshness contract. 
diff --git a/router/internal/codemode/sandbox/host.go b/router/internal/codemode/sandbox/host.go index c0fce81cc8..e81509be36 100644 --- a/router/internal/codemode/sandbox/host.go +++ b/router/internal/codemode/sandbox/host.go @@ -28,7 +28,6 @@ type executeState struct { req ExecuteRequest hostCalls atomic.Int32 qjsMu sync.Mutex - wg sync.WaitGroup } func (s *Sandbox) installHostInvoke(ctx context.Context, qctx *qjs.Context, state *executeState) { diff --git a/router/internal/codemode/server/server.go b/router/internal/codemode/server/server.go index c4c0eaee0f..1bf1e84a3f 100644 --- a/router/internal/codemode/server/server.go +++ b/router/internal/codemode/server/server.go @@ -28,11 +28,11 @@ import ( ) const ( - defaultListenAddr = "localhost:5027" - defaultExecuteTimeout = 120 * time.Second - defaultMaxResultBytes = 32 << 10 - mcpPath = "/mcp" - persistedOpsURI = "yoko://persisted-ops.d.ts" + defaultListenAddr = "localhost:5027" + defaultExecuteTimeout = 120 * time.Second + defaultMaxResultBytes = 32 << 10 + mcpPath = "/mcp" + persistedOpsURI = "yoko://persisted-ops.d.ts" statelessNamedOpsWarnMessage = "code mode named operations are disabled because MCP session stateless mode is enabled" namedOpsDisabledMessage = "named operations are disabled" ) @@ -74,9 +74,9 @@ type Server struct { tracerProvider trace.TracerProvider callTraceRecorder calltrace.Recorder - mcpServer *mcp.Server - searchGroup singleflight.Group - newOpsFragment func([]storage.SessionOp, *ast.Document) (string, error) + mcpServer *mcp.Server + searchGroup singleflight.Group + opsFragment func([]storage.SessionOp, *ast.Document) (string, error) mu sync.Mutex httpServer *http.Server @@ -132,7 +132,7 @@ func New(cfg Config) (*Server, error) { meter: meter, tracerProvider: cfg.TracerProvider, callTraceRecorder: cfg.CallTraceRecorder, - newOpsFragment: tsgen.NewOpsFragment, + opsFragment: tsgen.NewOpsFragment, } s.mcpServer = mcp.NewServer(&mcp.Implementation{ @@ -174,12 +174,16 @@ func (s *Server) 
Start(ctx context.Context) error { return err } + // WriteTimeout must exceed executeTimeout — net/http enforces it as a + // hard deadline on the whole response phase, which would cut off + // legitimately long code_mode_run_js calls. ReadHeaderTimeout bounds the + // header read separately so the listener still resists slow-loris clients. httpServer := &http.Server{ - Addr: s.listenAddr, - Handler: s.handler(), - ReadTimeout: 30 * time.Second, - WriteTimeout: 30 * time.Second, - IdleTimeout: 60 * time.Second, + Addr: s.listenAddr, + Handler: s.handler(), + ReadHeaderTimeout: 30 * time.Second, + WriteTimeout: s.executeTimeout + 30*time.Second, + IdleTimeout: 60 * time.Second, } s.mu.Lock() From b92593b22753d5b8c331bb8d3c40b1ab0c3224e5 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Thu, 7 May 2026 13:06:54 +0200 Subject: [PATCH 06/10] feat(code-mode): strip auth directives from demo subgraph schemas MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The Employee.startDate field carries @requiresScopes in the shared demo/pkg/subgraphs/employees schema; the router's authorizer is always active and rejects unauthenticated callers as soon as a scoped field is touched, breaking the code-mode demo. The shared schema is exercised by router-tests/security/authentication_test.go so it has to stay intact — instead, prepare-schemas.sh copies the four code-mode subgraph schemas into demo/code-mode/schemas/ with @authenticated and @requiresScopes directive applications removed, and graph.yaml composes from those copies. The make compose target depends on prepare-schemas so the local copies always track upstream. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- demo/code-mode/.gitignore | 1 + demo/code-mode/Makefile | 9 ++++++-- demo/code-mode/graph.yaml | 8 ++++---- demo/code-mode/prepare-schemas.sh | 34 +++++++++++++++++++++++++++++++ 4 files changed, 46 insertions(+), 6 deletions(-) create mode 100755 demo/code-mode/prepare-schemas.sh diff --git a/demo/code-mode/.gitignore b/demo/code-mode/.gitignore index bc5fd710be..c54f22c490 100644 --- a/demo/code-mode/.gitignore +++ b/demo/code-mode/.gitignore @@ -1 +1,2 @@ mcp-stdio-proxy/mcp-stdio-proxy +schemas/ diff --git a/demo/code-mode/Makefile b/demo/code-mode/Makefile index 1114f6ea7c..2c7f1aec8f 100644 --- a/demo/code-mode/Makefile +++ b/demo/code-mode/Makefile @@ -3,7 +3,7 @@ GOCACHE ?= /tmp/cosmo-code-mode-go-build-cache wgc_env_arg = $(if $(wildcard ../cli/.env),--env-file ../cli/.env,) wgc_router = pnpm dlx tsx $(wgc_env_arg) ../cli/src/index.ts router -.PHONY: build-yoko build-stdio-proxy compose start down run-subgraphs +.PHONY: build-yoko build-stdio-proxy prepare-schemas compose start down run-subgraphs build-yoko: mkdir -p $(GOCACHE) @@ -13,7 +13,12 @@ build-stdio-proxy: mkdir -p $(GOCACHE) cd mcp-stdio-proxy && GOCACHE=$(GOCACHE) go build -o mcp-stdio-proxy . -compose: +# Generate code-mode-local copies of the demo subgraph schemas with auth +# directives stripped so the demo runs without authentication. +prepare-schemas: + ./prepare-schemas.sh + +compose: prepare-schemas cd .. 
&& if [ -f ../cli/dist/src/index.js ]; then \ DISABLE_UPDATE_CHECK=true node ../cli/dist/src/index.js router compose -i ./code-mode/graph.yaml -o ./code-mode/config.json; \ else \ diff --git a/demo/code-mode/graph.yaml b/demo/code-mode/graph.yaml index e95412def2..c67180f295 100644 --- a/demo/code-mode/graph.yaml +++ b/demo/code-mode/graph.yaml @@ -3,16 +3,16 @@ subgraphs: - name: employees routing_url: http://localhost:4001/graphql schema: - file: ../pkg/subgraphs/employees/subgraph/schema.graphqls + file: schemas/employees.graphqls - name: family routing_url: http://localhost:4002/graphql schema: - file: ../pkg/subgraphs/family/subgraph/schema.graphqls + file: schemas/family.graphqls - name: availability routing_url: http://localhost:4007/graphql schema: - file: ../pkg/subgraphs/availability/subgraph/schema.graphqls + file: schemas/availability.graphqls - name: mood routing_url: http://localhost:4008/graphql schema: - file: ../pkg/subgraphs/mood/subgraph/schema.graphqls + file: schemas/mood.graphqls diff --git a/demo/code-mode/prepare-schemas.sh b/demo/code-mode/prepare-schemas.sh new file mode 100755 index 0000000000..92fd8f6dae --- /dev/null +++ b/demo/code-mode/prepare-schemas.sh @@ -0,0 +1,34 @@ +#!/usr/bin/env bash +# Generate code-mode-local copies of the demo subgraph schemas with the +# federation auth directives (@authenticated, @requiresScopes) stripped. +# +# The shared schemas under demo/pkg/subgraphs are used by router-tests and +# other demos that intentionally exercise authorization, so we don't touch +# them. The code-mode demo runs without authentication, and the router's +# CosmoAuthorizer always rejects unauthenticated requests on a scoped field; +# composing from these stripped copies keeps the demo working out of the box. 
+ +set -Eeuo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +SRC_DIR="$SCRIPT_DIR/../pkg/subgraphs" +OUT_DIR="$SCRIPT_DIR/schemas" + +mkdir -p "$OUT_DIR" + +strip_auth() { + local in="$1" + local out="$2" + # Remove @requiresScopes(scopes: [[...], [...]]) — match the doubly-nested + # bracket payload, then drop @authenticated standalone uses. The directive + # imports inside @link(import: [...]) stay (they're string literals, not + # directive applications, so they don't trigger enforcement). + sed -E ' + s/[[:space:]]*@requiresScopes\(scopes:[[:space:]]*\[(\[[^][]*\][, ]*)+\]\)//g + s/[[:space:]]*@authenticated\b//g + ' "$in" > "$out" +} + +for sg in employees family availability mood; do + strip_auth "$SRC_DIR/$sg/subgraph/schema.graphqls" "$OUT_DIR/$sg.graphqls" +done From 3795fc2cafb758c47beb855e097cf25aad3c1767 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Thu, 7 May 2026 15:45:03 +0200 Subject: [PATCH 07/10] refactor(code-mode): identify operations by content SHA instead of name MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Operation identity is now the ShortSHA over the canonical body — eight hex chars prefixed with 'o' so it's a valid JS identifier. The model calls tools.<sha>(...) inside the sandbox, the search response surfaces (sha, description) pairs plus a TS fragment that names each method by sha, and storage dedupes by canonical body alone. Previously, yoko regenerating an operation under the same document name ("getOrders" with v1 vs. v2 selection sets) silently overwrote the first registration under NormalizeName-based dedup. Tying identity to content makes that collision impossible — different bodies always get different identifiers, identical bodies always converge. DocumentName is preserved separately because the router's GraphQL parser still needs the literal operation name from the body when invoking against /graphql.
NormalizeName, SuffixedName, and the reserved-word table are gone; tests are rewritten to compute SHAs from bodies so expectations stay self-checking. router-tests duplicates a small codeModeShortSHA helper because it's a separate module and can't import internal/codemode/storage. Co-Authored-By: Claude Opus 4.7 (1M context) --- router-tests/code_mode_named_ops_test.go | 142 +++--- router/internal/codemode/sandbox/host.go | 12 +- .../codemode/server/search_handler.go | 225 +++++---- .../codemode/server/search_handler_test.go | 457 +++++++++++++----- .../codemode/storage/memory_backend.go | 46 +- .../codemode/storage/memory_backend_test.go | 260 +++++++--- router/internal/codemode/storage/naming.go | 212 ++------ .../internal/codemode/storage/naming_test.go | 98 ++-- .../codemode/storage/redis_backend.go | 33 +- .../codemode/storage/redis_backend_test.go | 148 +++++- router/internal/codemode/storage/types.go | 22 +- 11 files changed, 987 insertions(+), 668 deletions(-) diff --git a/router-tests/code_mode_named_ops_test.go b/router-tests/code_mode_named_ops_test.go index 40ea76a5fb..c58985b5f2 100644 --- a/router-tests/code_mode_named_ops_test.go +++ b/router-tests/code_mode_named_ops_test.go @@ -22,6 +22,9 @@ import ( "github.com/wundergraph/cosmo/router-tests/freeport" "github.com/wundergraph/cosmo/router-tests/testenv" + "crypto/sha256" + "encoding/hex" + yokov1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/code_mode/yoko/v1" yokoconnect "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/code_mode/yoko/v1/yokov1connect" nodev1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/node/v1" @@ -42,16 +45,34 @@ const ( updateTagMutation = `mutation updateEmployeeTag($id: Int!, $tag: String!) { updateEmployeeTag(id: $id, tag: $tag) { id tag } }` ) -const firstEmployeeTS = `/** Fetch the first employee. 
*/ -firstEmployee(): R<{ firstEmployee: { id: number; details: { forename: string; surname: string } | null } }>;` +// codeModeShortSHA mirrors router/internal/codemode/storage.ShortSHA. We +// duplicate it because router-tests is a separate module and can't import +// the internal package; if the production helper changes, this must too. +func codeModeShortSHA(body string) string { + canonical := strings.Join(strings.Fields(body), " ") + sum := sha256.Sum256([]byte(canonical)) + return "o" + hex.EncodeToString(sum[:])[:8] +} + +// SHA-derived JS identifiers — operations are exposed to the model as +// tools.<sha>(...) so collisions on document name don't conflate distinct +// bodies. Computed once at init to keep test expectations readable. +var ( + firstEmployeeSHA = codeModeShortSHA(firstEmployeeQuery) + employeeByIDSHA = codeModeShortSHA(employeeByIDQuery) + updateTagSHA = codeModeShortSHA(updateTagMutation) +) + +var firstEmployeeTS = `/** Fetch the first employee. */ +` + firstEmployeeSHA + `(): R<{ firstEmployee: { id: number; details: { forename: string; surname: string } | null } }>;` -const employeeByIDTS = `/** Fetch employee by id. */ -employeeByID(vars: { id: number }): R<{ employee: { id: number; details: { forename: string; surname: string } | null } | null }>;` +var employeeByIDTS = `/** Fetch employee by id. */ +` + employeeByIDSHA + `(vars: { id: number }): R<{ employee: { id: number; details: { forename: string; surname: string } | null } | null }>;` -const updateTagTS = `/** Update employee tag.
*/ +` + updateTagSHA + `(vars: { id: number; tag: string }): R<{ updateEmployeeTag: { id: number; tag: string } | null }>;` -const twoOpsFragment = firstEmployeeTS + "\n\n" + employeeByIDTS +var twoOpsFragment = firstEmployeeTS + "\n\n" + employeeByIDTS // indentBundleEntry mirrors tsgen's behavior: every line of a per-op block // (JSDoc + signature) is indented by 2 spaces inside the tools object. @@ -61,7 +82,6 @@ func indentBundleEntry(s string) string { const emptyOpsBundle = `type GraphQLError = { message: string; path?: (string | number)[]; extensions?: Record<string, unknown> }; type R<T> = Promise<{ data: T | null; errors?: GraphQLError[] }>; -// Known limitation: union and interface selections are typed as unknown. declare const tools: {}; @@ -70,7 +90,6 @@ declare function compact<T>(value: T): T;` var firstEmployeeBundle = `type GraphQLError = { message: string; path?: (string | number)[]; extensions?: Record<string, unknown> }; type R<T> = Promise<{ data: T | null; errors?: GraphQLError[] }>; -// Known limitation: union and interface selections are typed as unknown. declare const tools: { ` + indentBundleEntry(firstEmployeeTS) + ` @@ -81,7 +100,6 @@ declare function compact<T>(value: T): T;` var employeeByIDBundle = `type GraphQLError = { message: string; path?: (string | number)[]; extensions?: Record<string, unknown> }; type R<T> = Promise<{ data: T | null; errors?: GraphQLError[] }>; -// Known limitation: union and interface selections are typed as unknown. declare const tools: { ` + indentBundleEntry(employeeByIDTS) + ` @@ -92,7 +110,6 @@ declare function compact<T>(value: T): T;` var twoOpsBundle = `type GraphQLError = { message: string; path?: (string | number)[]; extensions?: Record<string, unknown> }; type R<T> = Promise<{ data: T | null; errors?: GraphQLError[] }>; -// Known limitation: union and interface selections are typed as unknown.
declare const tools: { ` + indentBundleEntry(firstEmployeeTS) + ` @@ -115,12 +132,11 @@ func TestCodeModeNamedOpsMemoryBackendStatefulSearchExecuteAndResource(t *testin "prompts": []string{"first employee", "employee by id"}, }) assert.Equal(t, twoOpsFragment, searchText) - assert.Equal(t, []*yokov1.IndexRequest{{SchemaSdl: yoko.indexRequests()[0].GetSchemaSdl()}}, yoko.indexRequests()) - assert.Equal(t, []*yokov1.SearchRequest{{ - Prompts: []string{"first employee", "employee by id"}, - SchemaId: "schema-1", - SessionId: yoko.searchRequests()[0].GetSessionId(), - }}, yoko.searchRequests()) + assert.Equal(t, []*yokov1.IndexSchemaRequest{{Sdl: yoko.indexRequests()[0].GetSdl()}}, yoko.indexRequests()) + assert.Equal(t, []*yokov1.GenerateQueryRequest{ + {Prompt: "first employee", SchemaId: "schema-1"}, + {Prompt: "employee by id", SchemaId: "schema-1"}, + }, yoko.generateRequests()) resource := readPersistedOpsResource(t, ctx, session) assert.Equal(t, &mcp.ReadResourceResult{Contents: []*mcp.ResourceContents{{ @@ -130,7 +146,7 @@ func TestCodeModeNamedOpsMemoryBackendStatefulSearchExecuteAndResource(t *testin }}}, resource) executeText := callCodeModeToolText(t, ctx, session, "code_mode_run_js", map[string]any{ - "source": `async () => { return await tools.employeeByID({ id: 1 }); }`, + "source": fmt.Sprintf(`async () => { return await tools.%s({ id: 1 }); }`, employeeByIDSHA), }) assert.Equal(t, map[string]any{ "result": map[string]any{ @@ -188,13 +204,13 @@ func TestCodeModeNamedOpsSchemaReloadEvictsSession(t *testing.T) { require.NoError(t, poller.updateConfig(poller.initConfig, "before-code-mode-reload")) executeText := callCodeModeToolText(t, ctx, session, "code_mode_run_js", map[string]any{ - "source": `async () => { return await tools.employeeByID({ id: 1 }); }`, + "source": fmt.Sprintf(`async () => { return await tools.%s({ id: 1 }); }`, employeeByIDSHA), }) assert.Equal(t, map[string]any{ "result": nil, "error": map[string]any{ "name": "TypeError", - 
"message": "tools.employeeByID is not a function", + "message": fmt.Sprintf("tools.%s is not a function", employeeByIDSHA), "stack": " at __agentMain (codemode_agent.js:agent.ts:1:34)\n at (codemode_agent.js:73:42)\n at (codemode_agent.js:77:1)\n", }, }, decodeJSON(t, executeText)) @@ -215,7 +231,7 @@ func TestCodeModeNamedOpsMutationElicitationRejection(t *testing.T) { assert.Equal(t, updateTagTS, searchText) executeText := callCodeModeToolText(t, ctx, session, "code_mode_run_js", map[string]any{ - "source": `async () => { return await tools.updateEmployeeTag({ id: 1, tag: "x" }); }`, + "source": fmt.Sprintf(`async () => { return await tools.%s({ id: 1, tag: "x" }); }`, updateTagSHA), }) assert.Equal(t, map[string]any{ "result": map[string]any{ @@ -334,7 +350,7 @@ func TestCodeModeNamedOpsRedisBackendTransparent(t *testing.T) { assert.Equal(t, twoOpsBundle, resource.Contents[0].Text) executeText := callCodeModeToolText(t, ctx, session, "code_mode_run_js", map[string]any{ - "source": `async () => { return await tools.employeeByID({ id: 1 }); }`, + "source": fmt.Sprintf(`async () => { return await tools.%s({ id: 1 }); }`, employeeByIDSHA), }) assert.Equal(t, map[string]any{ "result": map[string]any{ @@ -504,33 +520,33 @@ func mark3ResourcesContain(resources []mark3mcp.Resource, uri string) bool { } type fakeCodeModeYoko struct { - mu sync.Mutex - indexCounter int - indexRequestLog []*yokov1.IndexRequest - searchRequestLog []*yokov1.SearchRequest - opsByPrompt map[string]*yokov1.GeneratedOperation + mu sync.Mutex + indexCounter int + indexRequestLog []*yokov1.IndexSchemaRequest + generateRequestLog []*yokov1.GenerateQueryRequest + queriesByPrompt map[string]*yokov1.ResolvedQuery } func newFakeCodeModeYoko() *fakeCodeModeYoko { return &fakeCodeModeYoko{ - opsByPrompt: map[string]*yokov1.GeneratedOperation{ + queriesByPrompt: map[string]*yokov1.ResolvedQuery{ "first employee": { - Name: firstEmployeeOpName, - Body: firstEmployeeQuery, - Kind: 
yokov1.OperationKind_OPERATION_KIND_QUERY, - Description: "Fetch the first employee.", + OperationName: firstEmployeeOpName, + Document: firstEmployeeQuery, + OperationType: "query", + Description: "Fetch the first employee.", }, "employee by id": { - Name: employeeByIDOpName, - Body: employeeByIDQuery, - Kind: yokov1.OperationKind_OPERATION_KIND_QUERY, - Description: "Fetch employee by id.", + OperationName: employeeByIDOpName, + Document: employeeByIDQuery, + OperationType: "query", + Description: "Fetch employee by id.", }, "update employee tag": { - Name: updateTagOpName, - Body: updateTagMutation, - Kind: yokov1.OperationKind_OPERATION_KIND_MUTATION, - Description: "Update employee tag.", + OperationName: updateTagOpName, + Document: updateTagMutation, + OperationType: "mutation", + Description: "Update employee tag.", }, }, } @@ -551,50 +567,48 @@ func startFakeCodeModeYoko(t *testing.T, svc *fakeCodeModeYoko) *httptest.Server return server } -func (f *fakeCodeModeYoko) Index(_ context.Context, req *connect.Request[yokov1.IndexRequest]) (*connect.Response[yokov1.IndexResponse], error) { +func (f *fakeCodeModeYoko) IndexSchema(_ context.Context, req *connect.Request[yokov1.IndexSchemaRequest]) (*connect.Response[yokov1.IndexSchemaResponse], error) { f.mu.Lock() defer f.mu.Unlock() f.indexCounter++ - f.indexRequestLog = append(f.indexRequestLog, &yokov1.IndexRequest{SchemaSdl: req.Msg.GetSchemaSdl()}) - return connect.NewResponse(&yokov1.IndexResponse{SchemaId: fmt.Sprintf("schema-%d", f.indexCounter)}), nil + f.indexRequestLog = append(f.indexRequestLog, &yokov1.IndexSchemaRequest{Sdl: req.Msg.GetSdl()}) + return connect.NewResponse(&yokov1.IndexSchemaResponse{SchemaId: fmt.Sprintf("schema-%d", f.indexCounter)}), nil } -func (f *fakeCodeModeYoko) Search(_ context.Context, req *connect.Request[yokov1.SearchRequest]) (*connect.Response[yokov1.SearchResponse], error) { +func (f *fakeCodeModeYoko) GenerateQuery(_ context.Context, req 
*connect.Request[yokov1.GenerateQueryRequest]) (*connect.Response[yokov1.GenerateQueryResponse], error) { f.mu.Lock() defer f.mu.Unlock() - f.searchRequestLog = append(f.searchRequestLog, &yokov1.SearchRequest{ - Prompts: append([]string(nil), req.Msg.GetPrompts()...), - SchemaId: req.Msg.GetSchemaId(), - SessionId: req.Msg.GetSessionId(), + f.generateRequestLog = append(f.generateRequestLog, &yokov1.GenerateQueryRequest{ + SchemaId: req.Msg.GetSchemaId(), + Prompt: req.Msg.GetPrompt(), }) - ops := make([]*yokov1.GeneratedOperation, 0, len(req.Msg.GetPrompts())) - for _, prompt := range req.Msg.GetPrompts() { - if op := f.opsByPrompt[prompt]; op != nil { - ops = append(ops, op) - } + queries := make([]*yokov1.ResolvedQuery, 0, 1) + if q := f.queriesByPrompt[req.Msg.GetPrompt()]; q != nil { + queries = append(queries, q) } - return connect.NewResponse(&yokov1.SearchResponse{Operations: ops}), nil + return connect.NewResponse(&yokov1.GenerateQueryResponse{ + Resolution: &yokov1.Resolution{Queries: queries}, + }), nil } -func (f *fakeCodeModeYoko) indexRequests() []*yokov1.IndexRequest { +func (f *fakeCodeModeYoko) indexRequests() []*yokov1.IndexSchemaRequest { f.mu.Lock() defer f.mu.Unlock() - out := make([]*yokov1.IndexRequest, 0, len(f.indexRequestLog)) + out := make([]*yokov1.IndexSchemaRequest, 0, len(f.indexRequestLog)) for _, req := range f.indexRequestLog { - out = append(out, &yokov1.IndexRequest{SchemaSdl: req.GetSchemaSdl()}) + out = append(out, &yokov1.IndexSchemaRequest{Sdl: req.GetSdl()}) } return out } -func (f *fakeCodeModeYoko) searchRequests() []*yokov1.SearchRequest { +func (f *fakeCodeModeYoko) generateRequests() []*yokov1.GenerateQueryRequest { f.mu.Lock() defer f.mu.Unlock() - out := make([]*yokov1.SearchRequest, 0, len(f.searchRequestLog)) - for _, req := range f.searchRequestLog { - out = append(out, &yokov1.SearchRequest{ - Prompts: append([]string(nil), req.GetPrompts()...), - SchemaId: req.GetSchemaId(), - SessionId: req.GetSessionId(), + 
 out := make([]*yokov1.GenerateQueryRequest, 0, len(f.generateRequestLog)) + for _, req := range f.generateRequestLog { + out = append(out, &yokov1.GenerateQueryRequest{ + SchemaId: req.GetSchemaId(), + Prompt: req.GetPrompt(), }) } return out diff --git a/router/internal/codemode/sandbox/host.go b/router/internal/codemode/sandbox/host.go index e81509be36..7cddf4b952 100644 --- a/router/internal/codemode/sandbox/host.go +++ b/router/internal/codemode/sandbox/host.go @@ -134,9 +134,19 @@ func (s *Sandbox) invokeTool(ctx context.Context, state *executeState, name stri } } + // The operation name passed to /graphql must match the named operation + // inside op.Body. op.Name is the content-derived ShortSHA we expose to + // the model as `tools.<sha>`, but the document body still carries + // yoko's original operation name — so we send op.DocumentName when + // available, falling back to op.Name for legacy sessions written + // before this field existed. + opName := op.DocumentName + if opName == "" { + opName = name + } body, err := json.Marshal(graphQLRequest{ Query: op.Body, - OperationName: name, + OperationName: opName, Variables: vars, }) if err != nil { diff --git a/router/internal/codemode/server/search_handler.go b/router/internal/codemode/server/search_handler.go index 8860cc3eca..2aa1913085 100644 --- a/router/internal/codemode/server/search_handler.go +++ b/router/internal/codemode/server/search_handler.go @@ -15,12 +15,8 @@ import ( ) const ( - maxSearchPrompts = 20 - emptySearchAPIResponseMessage = "// 0 new ops; previous code_mode_search_tools calls already cover these prompts." - - // The generated proto currently has query and mutation constants. Yoko may - // still send the planned subscription enum value; host behavior is to drop it. - yokoOperationKindSubscription yokov1.OperationKind = 3 + maxSearchPrompts = 20 + noOperationsMessage = "// yoko returned no operations for these prompts. Restate with concrete entity/field names."
) type searchAPIInput struct { @@ -28,11 +24,17 @@ type searchAPIInput struct { } type legacyCatalogueOperation struct { - Name string `json:"name"` - Body string `json:"body"` - Kind string `json:"kind"` - Description string `json:"description"` - Variables *string `json:"variables"` + Name string `json:"name"` + Body string `json:"body"` + Kind string `json:"kind"` + Description string `json:"description"` + VariablesSchema string `json:"variables_schema,omitempty"` +} + +type legacyCatalogueResponse struct { + Operations []legacyCatalogueOperation `json:"operations"` + Unsatisfied []string `json:"unsatisfied,omitempty"` + Truncated bool `json:"truncated,omitempty"` } func (s *Server) handleSearchAPI(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { @@ -83,39 +85,36 @@ func decodeSearchPrompts(req *mcp.CallToolRequest) ([]string, error) { } func (s *Server) handleSearchStateless(ctx context.Context, prompts []string) *mcp.CallToolResult { - response, err := s.searchYoko(ctx, "", prompts) + resolution, err := s.searchYoko(ctx, prompts) if err != nil { return toolErrorResult(fmt.Sprintf("code_mode_search_tools: yoko search failed: %v", err)) } - catalogue := make([]legacyCatalogueOperation, 0, len(response.GetOperations())) - droppedSubscription := false - for _, op := range response.GetOperations() { - kind, ok, subscription := yokoOperationKindLabel(op.GetKind()) - if subscription { - droppedSubscription = true - continue - } + catalogue := make([]legacyCatalogueOperation, 0, len(resolution.GetQueries())) + for _, q := range resolution.GetQueries() { + kind, ok := operationKindLabel(q.GetOperationType()) if !ok { s.logger.Warn("code_mode_search_tools dropped unsupported operation kind", - zap.String("name", op.GetName()), - zap.String("kind", op.GetKind().String()), + zap.String("name", q.GetOperationName()), + zap.String("kind", q.GetOperationType()), ) continue } catalogue = append(catalogue, legacyCatalogueOperation{ - Name: 
op.GetName(), - Body: op.GetBody(), - Kind: kind, - Description: op.GetDescription(), - Variables: extractGraphQLVariablesBlock(op.GetBody()), + Name: storage.ShortSHA(q.GetDocument()), + Body: q.GetDocument(), + Kind: kind, + Description: q.GetDescription(), + VariablesSchema: q.GetVariablesSchema(), }) } - if droppedSubscription { - s.logger.Warn("code_mode_search_tools dropped subscription operations returned by yoko") - } - encoded, err := json.Marshal(catalogue) + response := legacyCatalogueResponse{ + Operations: catalogue, + Unsatisfied: unsatisfiedReasons(resolution), + Truncated: resolution.GetTruncated(), + } + encoded, err := json.Marshal(response) if err != nil { return toolErrorResult(fmt.Sprintf("code_mode_search_tools: failed to encode legacy catalogue: %v", err)) } @@ -123,94 +122,138 @@ func (s *Server) handleSearchStateless(ctx context.Context, prompts []string) *m } func (s *Server) handleSearchStateful(ctx context.Context, sessionID string, prompts []string) *mcp.CallToolResult { - response, err := s.searchYoko(ctx, sessionID, prompts) + resolution, err := s.searchYoko(ctx, prompts) if err != nil { return toolErrorResult(fmt.Sprintf("code_mode_search_tools: yoko search failed: %v", err)) } - rawOps := make([]storage.SessionOp, 0, len(response.GetOperations())) - droppedSubscription := false - for _, op := range response.GetOperations() { - kind, ok, subscription := storageOperationKind(op.GetKind()) - if subscription { - droppedSubscription = true - continue - } + rawOps := make([]storage.SessionOp, 0, len(resolution.GetQueries())) + for _, q := range resolution.GetQueries() { + kind, ok := storageOperationKind(q.GetOperationType()) if !ok { s.logger.Warn("code_mode_search_tools dropped unsupported operation kind", - zap.String("name", op.GetName()), - zap.String("kind", op.GetKind().String()), + zap.String("name", q.GetOperationName()), + zap.String("kind", q.GetOperationType()), ) continue } rawOps = append(rawOps, storage.SessionOp{ - Name: 
storage.NormalizeName(op.GetName()), - Body: op.GetBody(), - Kind: kind, - Description: op.GetDescription(), + Name: storage.ShortSHA(q.GetDocument()), + Body: q.GetDocument(), + Kind: kind, + DocumentName: q.GetOperationName(), + Description: q.GetDescription(), }) } - if droppedSubscription { - s.logger.Warn("code_mode_search_tools dropped subscription operations returned by yoko") - } + + notes := unsatisfactionNotes(resolution) if len(rawOps) == 0 { - return textResult(emptySearchAPIResponseMessage) + if notes != "" { + return textResult(notes + noOperationsMessage) + } + return textResult(noOperationsMessage) } if s.storage == nil { return toolErrorResult("code_mode_search_tools: failed to register ops: code mode storage is not configured") } - // Collision handling approach: Append-applies-suffix. The storage backend is - // the serialization point for a session and returns the final stored names. - appendedOps, err := s.storage.Append(ctx, sessionID, rawOps) + // Append returns one resolved SessionOp per input, mapping each yoko + // query to either a freshly-registered op or a pre-existing op it + // dedupes against by canonical body. Operation identity is the SHA + // over the body, so the same body always lands on the same op — yoko + // regenerating an operation under a different document name produces + // the same identifier. The model receives declarations for every + // match including reused ones, so a fresh context never has to + // introspect the session. 
+ matchedOps, err := s.storage.Append(ctx, sessionID, rawOps) if err != nil { return toolErrorResult(fmt.Sprintf("code_mode_search_tools: failed to register ops: %v", err)) } - if len(appendedOps) == 0 { - return textResult(emptySearchAPIResponseMessage) - } - rendered, err := s.newOpsFragment(appendedOps, s.storage.Schema()) + rendered, err := s.opsFragment(matchedOps, s.storage.Schema()) if err != nil { - return toolErrorResult(fmt.Sprintf("code_mode_search_tools: failed to render new ops: %v", err)) + return toolErrorResult(fmt.Sprintf("code_mode_search_tools: failed to render ops: %v", err)) + } + if notes != "" { + rendered = notes + "\n" + rendered } return textResult(rendered) } -func (s *Server) searchYoko(ctx context.Context, sessionID string, prompts []string) (*yokov1.SearchResponse, error) { +func (s *Server) searchYoko(ctx context.Context, prompts []string) (*yokov1.Resolution, error) { if s.yokoClient == nil { return nil, errors.New("yoko client is not configured") } - return s.yokoClient.Search(ctx, sessionID, prompts) + return s.yokoClient.Search(ctx, prompts) } -func storageOperationKind(kind yokov1.OperationKind) (storage.OperationKind, bool, bool) { - switch kind { - case yokov1.OperationKind_OPERATION_KIND_QUERY: - return storage.OperationKindQuery, true, false - case yokov1.OperationKind_OPERATION_KIND_MUTATION: - return storage.OperationKindMutation, true, false - case yokoOperationKindSubscription: - return "", false, true +func storageOperationKind(operationType string) (storage.OperationKind, bool) { + switch strings.ToLower(operationType) { + case "query": + return storage.OperationKindQuery, true + case "mutation": + return storage.OperationKindMutation, true default: - return "", false, false + return "", false } } -func yokoOperationKindLabel(kind yokov1.OperationKind) (string, bool, bool) { - switch kind { - case yokov1.OperationKind_OPERATION_KIND_QUERY: - return "Query", true, false - case 
yokov1.OperationKind_OPERATION_KIND_MUTATION: - return "Mutation", true, false - case yokoOperationKindSubscription: - return "", false, true +func operationKindLabel(operationType string) (string, bool) { + switch strings.ToLower(operationType) { + case "query": + return "Query", true + case "mutation": + return "Mutation", true default: - return "", false, false + return "", false } } +func unsatisfiedReasons(resolution *yokov1.Resolution) []string { + items := resolution.GetUnsatisfied() + if len(items) == 0 { + return nil + } + out := make([]string, 0, len(items)) + for _, u := range items { + reason := strings.TrimSpace(u.GetReason()) + if reason == "" { + continue + } + out = append(out, reason) + } + if len(out) == 0 { + return nil + } + return out +} + +// unsatisfactionNotes formats unsatisfied requirements (and the truncated flag) +// as a leading TS-comment block prepended to the bundle fragment, so the model +// reading the search response can see what could not be satisfied. +func unsatisfactionNotes(resolution *yokov1.Resolution) string { + reasons := unsatisfiedReasons(resolution) + truncated := resolution.GetTruncated() + if len(reasons) == 0 && !truncated { + return "" + } + + var b strings.Builder + if len(reasons) > 0 { + b.WriteString("// unsatisfied: yoko could not satisfy the following requirement(s):\n") + for _, reason := range reasons { + b.WriteString("// - ") + b.WriteString(reason) + b.WriteByte('\n') + } + } + if truncated { + b.WriteString("// truncated: yoko ran out of turns before committing every requirement; consider tightening the prompt.\n") + } + return b.String() +} + func searchSingleFlightKey(sessionID string, prompts []string) string { sortedPrompts := append([]string(nil), prompts...) 
sort.Strings(sortedPrompts) @@ -221,32 +264,6 @@ func searchSingleFlightKey(sessionID string, prompts []string) string { return strings.Join(keyParts, "|") } -func extractGraphQLVariablesBlock(body string) *string { - open := strings.IndexByte(body, '(') - if open < 0 { - return nil - } - selection := strings.IndexByte(body, '{') - if selection >= 0 && selection < open { - return nil - } - - depth := 0 - for i := open; i < len(body); i++ { - switch body[i] { - case '(': - depth++ - case ')': - depth-- - if depth == 0 { - value := strings.TrimSpace(body[open : i+1]) - return &value - } - } - } - return nil -} - func textResult(text string) *mcp.CallToolResult { return &mcp.CallToolResult{ Content: []mcp.Content{&mcp.TextContent{Text: text}}, diff --git a/router/internal/codemode/server/search_handler_test.go b/router/internal/codemode/server/search_handler_test.go index df4c9fdeff..df65cdd15a 100644 --- a/router/internal/codemode/server/search_handler_test.go +++ b/router/internal/codemode/server/search_handler_test.go @@ -49,7 +49,7 @@ type Customer { } ` -const emptySearchMessage = "// 0 new ops; previous code_mode_search_tools calls already cover these prompts." +const noQueriesFromYokoMessage = "// yoko returned no operations for these prompts. Restate with concrete entity/field names." 
func TestHandleSearchValidatesPrompts(t *testing.T) { tests := []struct { @@ -99,20 +99,23 @@ func TestHandleSearchValidatesPrompts(t *testing.T) { func TestHandleSearchStatelessReturnsLegacyJSONCatalogue(t *testing.T) { searcher := newFakeYoko() - searcher.responses <- &yokov1.SearchResponse{Operations: []*yokov1.GeneratedOperation{ - { - Name: "getOrders", - Body: "query GetOrders($limit: Int) { orders(limit: $limit) { id } }", - Kind: yokov1.OperationKind_OPERATION_KIND_QUERY, - Description: "Fetch orders.", - }, - { - Name: "watchOrders", - Body: "subscription WatchOrders { orders { id } }", - Kind: yokoOperationKindSubscription, - Description: "Watch orders.", + searcher.responses <- &yokov1.Resolution{ + Queries: []*yokov1.ResolvedQuery{ + { + OperationName: "getOrders", + Document: "query GetOrders($limit: Int) { orders(limit: $limit) { id } }", + OperationType: "query", + Description: "Fetch orders.", + VariablesSchema: `{"type":"object","properties":{"limit":{"type":["integer","null"]}}}`, + }, + { + OperationName: "watchOrders", + Document: "subscription WatchOrders { orders { id } }", + OperationType: "subscription", + Description: "Watch orders.", + }, }, - }} + } store := newSearchTestStorage(t) srv := newSearchTestServer(t, true, searcher, store) @@ -121,42 +124,47 @@ func TestHandleSearchStatelessReturnsLegacyJSONCatalogue(t *testing.T) { })) require.NoError(t, err) - expectedJSON := mustJSON(t, []legacyCatalogueEntry{ - { - Name: "getOrders", - Body: "query GetOrders($limit: Int) { orders(limit: $limit) { id } }", - Kind: "Query", - Description: "Fetch orders.", - Variables: ptrString("($limit: Int)"), + getOrdersBody := "query GetOrders($limit: Int) { orders(limit: $limit) { id } }" + expectedJSON := mustJSON(t, legacyCatalogueResponse{ + Operations: []legacyCatalogueOperation{ + { + Name: storage.ShortSHA(getOrdersBody), + Body: getOrdersBody, + Kind: "Query", + Description: "Fetch orders.", + VariablesSchema: 
`{"type":"object","properties":{"limit":{"type":["integer","null"]}}}`, + }, }, }) assert.Equal(t, textToolResult(expectedJSON), got) - assert.Equal(t, []searchCall{{sessionID: "", prompts: []string{"orders"}}}, searcher.callsSnapshot()) + assert.Equal(t, []searchCall{{prompts: []string{"orders"}}}, searcher.callsSnapshot()) assert.Equal(t, []storage.SessionOp(nil), store.opsSnapshot("session-1")) } func TestHandleSearchStatefulAppendsAndReturnsNewOpsFragment(t *testing.T) { searcher := newFakeYoko() - searcher.responses <- &yokov1.SearchResponse{Operations: []*yokov1.GeneratedOperation{ - { - Name: "getOrders", - Body: "query GetOrders($limit: Int) { orders(limit: $limit) { id total } }", - Kind: yokov1.OperationKind_OPERATION_KIND_QUERY, - Description: "Fetch orders.", - }, - { - Name: "cancelOrder", - Body: "mutation CancelOrder($id: ID!) { cancelOrder(id: $id) { id } }", - Kind: yokov1.OperationKind_OPERATION_KIND_MUTATION, - Description: "Cancel an order.", - }, - { - Name: "watchOrders", - Body: "subscription WatchOrders { orders { id } }", - Kind: yokoOperationKindSubscription, - Description: "Watch orders.", + searcher.responses <- &yokov1.Resolution{ + Queries: []*yokov1.ResolvedQuery{ + { + OperationName: "getOrders", + Document: "query GetOrders($limit: Int) { orders(limit: $limit) { id total } }", + OperationType: "query", + Description: "Fetch orders.", + }, + { + OperationName: "cancelOrder", + Document: "mutation CancelOrder($id: ID!) 
{ cancelOrder(id: $id) { id } }", + OperationType: "mutation", + Description: "Cancel an order.", + }, + { + OperationName: "watchOrders", + Document: "subscription WatchOrders { orders { id } }", + OperationType: "subscription", + Description: "Watch orders.", + }, }, - }} + } store := newSearchTestStorage(t) srv := newSearchTestServer(t, false, searcher, store) @@ -165,34 +173,164 @@ func TestHandleSearchStatefulAppendsAndReturnsNewOpsFragment(t *testing.T) { })) require.NoError(t, err) + getOrdersBody := "query GetOrders($limit: Int) { orders(limit: $limit) { id total } }" + cancelOrderBody := "mutation CancelOrder($id: ID!) { cancelOrder(id: $id) { id } }" wantOps := []storage.SessionOp{ { - Name: "getOrders", - Body: "query GetOrders($limit: Int) { orders(limit: $limit) { id total } }", - Kind: storage.OperationKindQuery, - Description: "Fetch orders.", + Name: storage.ShortSHA(getOrdersBody), + Body: getOrdersBody, + Kind: storage.OperationKindQuery, + DocumentName: "getOrders", + Description: "Fetch orders.", }, { - Name: "cancelOrder", - Body: "mutation CancelOrder($id: ID!) 
{ cancelOrder(id: $id) { id } }", - Kind: storage.OperationKindMutation, - Description: "Cancel an order.", + Name: storage.ShortSHA(cancelOrderBody), + Body: cancelOrderBody, + Kind: storage.OperationKindMutation, + DocumentName: "cancelOrder", + Description: "Cancel an order.", }, } wantFragment, err := tsgen.NewOpsFragment(wantOps, searchHandlerTestSchema(t)) require.NoError(t, err) assert.Equal(t, textToolResult(wantFragment), got) assert.Equal(t, wantOps, store.opsSnapshot("session-1")) - assert.Equal(t, []searchCall{{sessionID: "session-1", prompts: []string{"orders", "cancel order"}}}, searcher.callsSnapshot()) + assert.Equal(t, []searchCall{{prompts: []string{"orders", "cancel order"}}}, searcher.callsSnapshot()) } -func TestHandleSearchFallsBackToStatelessWhenSessionIDMissing(t *testing.T) { +func TestHandleSearchStatefulHashesNameButPreservesDocumentName(t *testing.T) { + // Regression: yoko returns operation_name in any casing it likes, + // and the same document name can mask different bodies. Storage Name + // must be the content-derived ShortSHA so collisions on the document + // name don't conflate distinct operations, but DocumentName must be + // the original name so the host bridge can match the operation + // inside Body when invoking. 
searcher := newFakeYoko() - searcher.responses <- &yokov1.SearchResponse{Operations: []*yokov1.GeneratedOperation{{ - Name: "getOrders", - Body: "query GetOrders { orders { id } }", - Kind: yokov1.OperationKind_OPERATION_KIND_QUERY, + searcher.responses <- &yokov1.Resolution{ + Queries: []*yokov1.ResolvedQuery{{ + OperationName: "GetCustomerContractDetails", + Document: "query GetCustomerContractDetails { orders { id } }", + OperationType: "query", + Description: "Fetch contract details.", + }}, + } + store := newSearchTestStorage(t) + srv := newSearchTestServer(t, false, searcher, store) + + _, err := srv.handleSearch(context.Background(), searchToolRequest(t, "session-1", map[string]any{ + "prompts": []string{"contract"}, + })) + + require.NoError(t, err) + body := "query GetCustomerContractDetails { orders { id } }" + wantOps := []storage.SessionOp{{ + Name: storage.ShortSHA(body), + Body: body, + Kind: storage.OperationKindQuery, + DocumentName: "GetCustomerContractDetails", + Description: "Fetch contract details.", + }} + assert.Equal(t, wantOps, store.opsSnapshot("session-1")) +} + +func TestHandleSearchStatefulForwardsUnsatisfiedAndTruncated(t *testing.T) { + searcher := newFakeYoko() + searcher.responses <- &yokov1.Resolution{ + Queries: []*yokov1.ResolvedQuery{{ + OperationName: "getOrders", + Document: "query GetOrders { orders { id } }", + OperationType: "query", + Description: "Fetch orders.", + }}, + Unsatisfied: []*yokov1.Unsatisfied{ + {Reason: "no field on the schema carries that filter dimension"}, + {Reason: "customer filter not supported"}, + }, + Truncated: true, + } + store := newSearchTestStorage(t) + srv := newSearchTestServer(t, false, searcher, store) + + got, err := srv.handleSearch(context.Background(), searchToolRequest(t, "session-1", map[string]any{ + "prompts": []string{"orders", "filtered orders"}, + })) + + require.NoError(t, err) + getOrdersBody := "query GetOrders { orders { id } }" + wantOps := []storage.SessionOp{{ + Name: 
storage.ShortSHA(getOrdersBody), + Body: getOrdersBody, + Kind: storage.OperationKindQuery, Description: "Fetch orders.", + }} + wantFragment, err := tsgen.NewOpsFragment(wantOps, searchHandlerTestSchema(t)) + require.NoError(t, err) + wantText := "// unsatisfied: yoko could not satisfy the following requirement(s):\n" + + "// - no field on the schema carries that filter dimension\n" + + "// - customer filter not supported\n" + + "// truncated: yoko ran out of turns before committing every requirement; consider tightening the prompt.\n" + + "\n" + wantFragment + assert.Equal(t, textToolResult(wantText), got) +} + +func TestHandleSearchStatefulNoOpsWithUnsatisfiedReturnsNotes(t *testing.T) { + searcher := newFakeYoko() + searcher.responses <- &yokov1.Resolution{ + Unsatisfied: []*yokov1.Unsatisfied{{Reason: "not possible"}}, + } + srv := newSearchTestServer(t, false, searcher, newSearchTestStorage(t)) + + got, err := srv.handleSearch(context.Background(), searchToolRequest(t, "session-1", map[string]any{ + "prompts": []string{"orders"}, + })) + + require.NoError(t, err) + wantText := "// unsatisfied: yoko could not satisfy the following requirement(s):\n" + + "// - not possible\n" + + noQueriesFromYokoMessage + assert.Equal(t, textToolResult(wantText), got) +} + +func TestHandleSearchStatelessForwardsUnsatisfiedAndTruncated(t *testing.T) { + searcher := newFakeYoko() + searcher.responses <- &yokov1.Resolution{ + Queries: []*yokov1.ResolvedQuery{{ + OperationName: "getOrders", + Document: "query GetOrders { orders { id } }", + OperationType: "query", + Description: "Fetch orders.", + }}, + Unsatisfied: []*yokov1.Unsatisfied{{Reason: "no field for that filter"}}, + Truncated: true, + } + srv := newSearchTestServer(t, true, searcher, newSearchTestStorage(t)) + + got, err := srv.handleSearch(context.Background(), searchToolRequest(t, "", map[string]any{ + "prompts": []string{"orders"}, + })) + + require.NoError(t, err) + getOrdersBody := "query GetOrders { orders { id } 
}" + expectedJSON := mustJSON(t, legacyCatalogueResponse{ + Operations: []legacyCatalogueOperation{{ + Name: storage.ShortSHA(getOrdersBody), + Body: getOrdersBody, + Kind: "Query", + Description: "Fetch orders.", + }}, + Unsatisfied: []string{"no field for that filter"}, + Truncated: true, + }) + assert.Equal(t, textToolResult(expectedJSON), got) +} + +func TestHandleSearchFallsBackToStatelessWhenSessionIDMissing(t *testing.T) { + searcher := newFakeYoko() + searcher.responses <- &yokov1.Resolution{Queries: []*yokov1.ResolvedQuery{{ + OperationName: "getOrders", + Document: "query GetOrders { orders { id } }", + OperationType: "query", + Description: "Fetch orders.", }}} store := newSearchTestStorage(t) srv := newSearchTestServer(t, false, searcher, store) @@ -202,30 +340,37 @@ func TestHandleSearchFallsBackToStatelessWhenSessionIDMissing(t *testing.T) { })) require.NoError(t, err) - expectedJSON := mustJSON(t, []legacyCatalogueEntry{{ - Name: "getOrders", - Body: "query GetOrders { orders { id } }", - Kind: "Query", - Description: "Fetch orders.", - Variables: nil, - }}) + getOrdersBody := "query GetOrders { orders { id } }" + expectedJSON := mustJSON(t, legacyCatalogueResponse{ + Operations: []legacyCatalogueOperation{{ + Name: storage.ShortSHA(getOrdersBody), + Body: getOrdersBody, + Kind: "Query", + Description: "Fetch orders.", + }}, + }) assert.Equal(t, textToolResult(expectedJSON), got) assert.Equal(t, []storage.SessionOp(nil), store.opsSnapshot("session-1")) } -func TestHandleSearchNamingCollisionUsesFinalStoredName(t *testing.T) { +func TestHandleSearchSameDocumentNameDifferentBodiesRegistersBoth(t *testing.T) { + // Regression: yoko regenerates the same document name with a different + // body. With SHA-based identity each body lands as its own entry — + // previously the new body was silently dropped under the old name. 
searcher := newFakeYoko() - searcher.responses <- &yokov1.SearchResponse{Operations: []*yokov1.GeneratedOperation{{ - Name: "getOrders", - Body: "query GetOrdersAgain { orders { total } }", - Kind: yokov1.OperationKind_OPERATION_KIND_QUERY, - Description: "Fetch order totals.", + searcher.responses <- &yokov1.Resolution{Queries: []*yokov1.ResolvedQuery{{ + OperationName: "getOrders", + Document: "query getOrders { orders { id total } }", + OperationType: "query", + Description: "Fetch order totals.", }}} store := newSearchTestStorage(t) + originalBody := "query getOrders { orders { id } }" _, err := store.Append(context.Background(), "session-1", []storage.SessionOp{{ - Name: "getOrders", - Body: "query GetOrders { orders { id } }", - Kind: storage.OperationKindQuery, + Name: storage.ShortSHA(originalBody), + Body: originalBody, + Kind: storage.OperationKindQuery, + DocumentName: "getOrders", }}) require.NoError(t, err) srv := newSearchTestServer(t, false, searcher, store) @@ -235,19 +380,72 @@ func TestHandleSearchNamingCollisionUsesFinalStoredName(t *testing.T) { })) require.NoError(t, err) - wantOps := []storage.SessionOp{ - {Name: "getOrders", Body: "query GetOrders { orders { id } }", Kind: storage.OperationKindQuery}, - {Name: "getOrders_2", Body: "query GetOrdersAgain { orders { total } }", Kind: storage.OperationKindQuery, Description: "Fetch order totals."}, + newBody := "query getOrders { orders { id total } }" + newOp := storage.SessionOp{ + Name: storage.ShortSHA(newBody), + Body: newBody, + Kind: storage.OperationKindQuery, + DocumentName: "getOrders", + Description: "Fetch order totals.", } - wantFragment, err := tsgen.NewOpsFragment(wantOps[1:], searchHandlerTestSchema(t)) + wantFragment, err := tsgen.NewOpsFragment([]storage.SessionOp{newOp}, searchHandlerTestSchema(t)) + require.NoError(t, err) + assert.Equal(t, textToolResult(wantFragment), got) + assert.Equal(t, []storage.SessionOp{ + { + Name: storage.ShortSHA(originalBody), + Body: 
originalBody, + Kind: storage.OperationKindQuery, + DocumentName: "getOrders", + }, + newOp, + }, store.opsSnapshot("session-1")) +} + +func TestHandleSearchExistingOpsAreReRenderedOnRepeatPrompt(t *testing.T) { + // Regression for the fresh-context bug: when yoko returns ops that the + // session already has, the handler must still emit their TS declarations + // so a fresh model context can use them without introspecting `tools`. + body := "query GetOrders { orders { id } }" + sha := storage.ShortSHA(body) + + searcher := newFakeYoko() + searcher.responses <- &yokov1.Resolution{Queries: []*yokov1.ResolvedQuery{{ + OperationName: "GetOrders", + Document: body, + OperationType: "query", + Description: "Fetch orders.", + }}} + store := newSearchTestStorage(t) + _, err := store.Append(context.Background(), "session-1", []storage.SessionOp{{ + Name: sha, + Body: body, + Kind: storage.OperationKindQuery, + DocumentName: "GetOrders", + }}) + require.NoError(t, err) + srv := newSearchTestServer(t, false, searcher, store) + + got, err := srv.handleSearch(context.Background(), searchToolRequest(t, "session-1", map[string]any{ + "prompts": []string{"orders"}, + })) + + require.NoError(t, err) + wantOps := []storage.SessionOp{{ + Name: sha, + Body: body, + Kind: storage.OperationKindQuery, + DocumentName: "GetOrders", + }} + wantFragment, err := tsgen.NewOpsFragment(wantOps, searchHandlerTestSchema(t)) require.NoError(t, err) assert.Equal(t, textToolResult(wantFragment), got) assert.Equal(t, wantOps, store.opsSnapshot("session-1")) } -func TestHandleSearchEmptyYokoResponseIsSuccess(t *testing.T) { +func TestHandleSearchEmptyYokoResponseReturnsNoQueriesMessage(t *testing.T) { searcher := newFakeYoko() - searcher.responses <- &yokov1.SearchResponse{} + searcher.responses <- &yokov1.Resolution{} srv := newSearchTestServer(t, false, searcher, newSearchTestStorage(t)) got, err := srv.handleSearch(context.Background(), searchToolRequest(t, "session-1", map[string]any{ @@ -255,7 
+453,7 @@ func TestHandleSearchEmptyYokoResponseIsSuccess(t *testing.T) { })) require.NoError(t, err) - assert.Equal(t, textToolResult(emptySearchMessage), got) + assert.Equal(t, textToolResult(noQueriesFromYokoMessage), got) } func TestHandleSearchDoesNotRetryNotFoundFromSearcher(t *testing.T) { @@ -289,10 +487,10 @@ func TestHandleSearchSingleFlight(t *testing.T) { t.Run("identical calls share leader result", func(t *testing.T) { searcher := newFakeYoko() searcher.block = make(chan struct{}) - searcher.responses <- &yokov1.SearchResponse{Operations: []*yokov1.GeneratedOperation{{ - Name: "getOrders", - Body: "query GetOrders { orders { id } }", - Kind: yokov1.OperationKind_OPERATION_KIND_QUERY, + searcher.responses <- &yokov1.Resolution{Queries: []*yokov1.ResolvedQuery{{ + OperationName: "getOrders", + Document: "query GetOrders { orders { id } }", + OperationType: "query", }}} srv := newSearchTestServer(t, false, searcher, newSearchTestStorage(t)) @@ -328,8 +526,8 @@ func TestHandleSearchSingleFlight(t *testing.T) { t.Run("different calls do not share result", func(t *testing.T) { searcher := newFakeYoko() - searcher.responses <- &yokov1.SearchResponse{} - searcher.responses <- &yokov1.SearchResponse{} + searcher.responses <- &yokov1.Resolution{} + searcher.responses <- &yokov1.Resolution{} srv := newSearchTestServer(t, false, searcher, newSearchTestStorage(t)) var wg sync.WaitGroup @@ -351,8 +549,8 @@ func TestHandleSearchSingleFlight(t *testing.T) { t.Run("ambiguous spacing prompt sets do not share result", func(t *testing.T) { searcher := newFakeYoko() searcher.block = make(chan struct{}) - searcher.responses <- &yokov1.SearchResponse{} - searcher.responses <- &yokov1.SearchResponse{} + searcher.responses <- &yokov1.Resolution{} + searcher.responses <- &yokov1.Resolution{} srv := newSearchTestServer(t, false, searcher, newSearchTestStorage(t)) var wg sync.WaitGroup @@ -381,13 +579,13 @@ func TestHandleSearchSingleFlight(t *testing.T) { func 
TestHandleSearchRenderErrorIsToolError(t *testing.T) { searcher := newFakeYoko() - searcher.responses <- &yokov1.SearchResponse{Operations: []*yokov1.GeneratedOperation{{ - Name: "getOrders", - Body: "query GetOrders { orders { id } }", - Kind: yokov1.OperationKind_OPERATION_KIND_QUERY, + searcher.responses <- &yokov1.Resolution{Queries: []*yokov1.ResolvedQuery{{ + OperationName: "getOrders", + Document: "query GetOrders { orders { id } }", + OperationType: "query", }}} srv := newSearchTestServer(t, false, searcher, newSearchTestStorage(t)) - srv.newOpsFragment = func([]storage.SessionOp, *ast.Document) (string, error) { + srv.opsFragment = func([]storage.SessionOp, *ast.Document) (string, error) { return "", errors.New("render exploded") } @@ -396,7 +594,7 @@ func TestHandleSearchRenderErrorIsToolError(t *testing.T) { })) require.NoError(t, err) - assert.Equal(t, toolError("code_mode_search_tools: failed to render new ops: render exploded"), got) + assert.Equal(t, toolError("code_mode_search_tools: failed to render ops: render exploded"), got) } func TestHandleSearchCancelMaySurfaceLeaderCancellationToFollower(t *testing.T) { @@ -440,29 +638,30 @@ func TestHandleSearchCancelMaySurfaceLeaderCancellationToFollower(t *testing.T) } type searchCall struct { - sessionID string - prompts []string + prompts []string } type fakeYoko struct { - mu sync.Mutex - calls []searchCall - responses chan *yokov1.SearchResponse - errs chan error - block chan struct{} - schema string + mu sync.Mutex + calls []searchCall + responses chan *yokov1.Resolution + errs chan error + block chan struct{} + schema string + ensureIndexedCalled int + ensureIndexedErr error } func newFakeYoko() *fakeYoko { return &fakeYoko{ - responses: make(chan *yokov1.SearchResponse, 16), + responses: make(chan *yokov1.Resolution, 16), errs: make(chan error, 16), } } -func (f *fakeYoko) Search(ctx context.Context, sessionID string, prompts []string) (*yokov1.SearchResponse, error) { +func (f *fakeYoko) 
Search(ctx context.Context, prompts []string) (*yokov1.Resolution, error) { f.mu.Lock() - f.calls = append(f.calls, searchCall{sessionID: sessionID, prompts: append([]string(nil), prompts...)}) + f.calls = append(f.calls, searchCall{prompts: append([]string(nil), prompts...)}) f.mu.Unlock() if f.block != nil { @@ -482,7 +681,7 @@ func (f *fakeYoko) Search(ctx context.Context, sessionID string, prompts []strin case response := <-f.responses: return response, nil default: - return &yokov1.SearchResponse{}, nil + return &yokov1.Resolution{}, nil } } @@ -498,6 +697,22 @@ func (f *fakeYoko) Schema() string { return f.schema } +// EnsureIndexed records that eager warm-up was requested and returns the +// stubbed ensureIndexedErr (nil by default). The fake does not model an +// index cache; the body is otherwise a no-op. +func (f *fakeYoko) EnsureIndexed(context.Context) error { + f.mu.Lock() + defer f.mu.Unlock() + f.ensureIndexedCalled++ + return f.ensureIndexedErr +} + +func (f *fakeYoko) ensureIndexedCallCount() int { + f.mu.Lock() + defer f.mu.Unlock() + return f.ensureIndexedCalled +} + func (f *fakeYoko) callCount() int { f.mu.Lock() defer f.mu.Unlock() @@ -509,7 +724,7 @@ func (f *fakeYoko) callsSnapshot() []searchCall { defer f.mu.Unlock() calls := make([]searchCall, 0, len(f.calls)) for _, call := range f.calls { - calls = append(calls, searchCall{sessionID: call.sessionID, prompts: append([]string(nil), call.prompts...)}) + calls = append(calls, searchCall{prompts: append([]string(nil), call.prompts...)}) } return calls } @@ -535,19 +750,23 @@ func (s *searchTestStorage) Append(ctx context.Context, sessionID string, ops [] s.mu.Lock() defer s.mu.Unlock() - taken := make(map[string]struct{}, len(s.ops[sessionID])+len(ops)) - for _, op := range s.ops[sessionID] { - taken[op.Name] = struct{}{} + byBody := make(map[string]storage.SessionOp, len(s.ops[sessionID])+len(ops)) + for _, existing := range s.ops[sessionID] { + byBody[storage.CanonicalBody(existing.Body)] = 
existing } - appended := make([]storage.SessionOp, 0, len(ops)) + resolved := make([]storage.SessionOp, 0, len(ops)) for _, op := range ops { - op.Name = storage.SuffixedName(storage.NormalizeName(op.Name), taken) - taken[op.Name] = struct{}{} + canonical := storage.CanonicalBody(op.Body) + if existing, ok := byBody[canonical]; ok { + resolved = append(resolved, existing) + continue + } s.ops[sessionID] = append(s.ops[sessionID], op) - appended = append(appended, op) + byBody[canonical] = op + resolved = append(resolved, op) } - return appended, nil + return resolved, nil } func (s *searchTestStorage) GetOp(_ context.Context, sessionID string, name string) (storage.SessionOp, bool, error) { @@ -608,14 +827,6 @@ func (s *searchTestStorage) opsSnapshot(sessionID string) []storage.SessionOp { return append([]storage.SessionOp(nil), s.ops[sessionID]...) } -type legacyCatalogueEntry struct { - Name string `json:"name"` - Body string `json:"body"` - Kind string `json:"kind"` - Description string `json:"description"` - Variables *string `json:"variables"` -} - func newSearchTestServer(t *testing.T, stateless bool, searcher *fakeYoko, store *searchTestStorage) *Server { t.Helper() srv, err := New(Config{ @@ -650,10 +861,6 @@ func textToolResult(text string) *mcp.CallToolResult { } } -func ptrString(value string) *string { - return &value -} - func searchHandlerTestSchema(t *testing.T) *ast.Document { t.Helper() doc, report := astparser.ParseGraphqlDocumentString(searchHandlerTestSchemaSDL) diff --git a/router/internal/codemode/storage/memory_backend.go b/router/internal/codemode/storage/memory_backend.go index 7467d17635..4de8f81c4b 100644 --- a/router/internal/codemode/storage/memory_backend.go +++ b/router/internal/codemode/storage/memory_backend.go @@ -83,6 +83,15 @@ func NewMemoryBackend(config MemoryConfig) *MemoryBackend { } } +// Append resolves each input op against the session and returns one +// SessionOp per input — either the freshly-registered op or the +// 
pre-existing op it was deduped against. The model receives declarations +// for every matched op regardless of whether it was new, so a fresh +// context never has to introspect the session to discover prior ops. +// +// Operation identity is op.Name (a content-derived SHA produced by the +// caller via ShortSHA). Body and Name are 1:1 — the same body always +// hashes to the same Name, so dedup is a single CanonicalBody check. func (b *MemoryBackend) Append(ctx context.Context, sessionID string, ops []SessionOp) ([]SessionOp, error) { if err := ctx.Err(); err != nil { return nil, err @@ -93,24 +102,37 @@ func (b *MemoryBackend) Append(ctx context.Context, sessionID string, ops []Sess session := b.loadOrCreateSession(sessionID) session.mu.Lock() - appended := make([]SessionOp, 0, len(ops)) - taken := make(map[string]struct{}, len(session.ops)+len(ops)) - for _, op := range session.ops { - taken[op.Name] = struct{}{} + + byBody := make(map[string]SessionOp, len(session.ops)+len(ops)) + for _, existing := range session.ops { + byBody[CanonicalBody(existing.Body)] = existing } + + resolved := make([]SessionOp, 0, len(ops)) + appendedAny := false for _, op := range ops { - op.Name = SuffixedName(NormalizeName(op.Name), taken) - taken[op.Name] = struct{}{} + canonical := CanonicalBody(op.Body) + if existing, ok := byBody[canonical]; ok { + resolved = append(resolved, existing) + continue + } session.ops = append(session.ops, op) - appended = append(appended, op) + byBody[canonical] = op + resolved = append(resolved, op) + appendedAny = true + } + + if appendedAny { + session.lastUsed = b.now() + session.bundle = "" + session.bundleValid = false } - session.lastUsed = b.now() - session.bundle = "" - session.bundleValid = false session.mu.Unlock() - b.enforceMaxSessions() - return appended, nil + if appendedAny { + b.enforceMaxSessions() + } + return resolved, nil } func (b *MemoryBackend) GetOp(ctx context.Context, sessionID string, name string) (SessionOp, bool, error) 
{ diff --git a/router/internal/codemode/storage/memory_backend_test.go b/router/internal/codemode/storage/memory_backend_test.go index 662816ce64..e4fdf1ce01 100644 --- a/router/internal/codemode/storage/memory_backend_test.go +++ b/router/internal/codemode/storage/memory_backend_test.go @@ -3,7 +3,6 @@ package storage import ( "context" "fmt" - "sort" "strings" "sync" "testing" @@ -68,41 +67,36 @@ func TestMemoryBackendAppendGetOpBundleResetRoundTrip(t *testing.T) { clock := newTestClock() backend := newTestBackend(t, clock, nil) + queryBody := "query GetUser { user { id } }" + mutationBody := "mutation DeleteUser { deleteUser(id: 1) }" + querySHA := ShortSHA(queryBody) + mutationSHA := ShortSHA(mutationBody) + ops := []SessionOp{ - {Name: "get-user", Body: "query GetUser { user { id } }", Kind: OperationKindQuery, Description: "Fetch a user"}, - {Name: "delete", Body: "mutation DeleteUser { deleteUser(id: 1) }", Kind: OperationKindMutation, Description: "Delete a user"}, - {Name: "get-user", Body: "query GetUserAgain { user { name } }", Kind: OperationKindQuery, Description: "Fetch user name"}, + {Name: querySHA, Body: queryBody, Kind: OperationKindQuery, Description: "Fetch a user"}, + {Name: mutationSHA, Body: mutationBody, Kind: OperationKindMutation, Description: "Delete a user"}, } appended, err := backend.Append(ctx, "session-1", ops) require.NoError(t, err) - assert.Equal(t, []SessionOp{ - {Name: "getUser", Body: "query GetUser { user { id } }", Kind: OperationKindQuery, Description: "Fetch a user"}, - {Name: "op_delete", Body: "mutation DeleteUser { deleteUser(id: 1) }", Kind: OperationKindMutation, Description: "Delete a user"}, - {Name: "getUser_2", Body: "query GetUserAgain { user { name } }", Kind: OperationKindQuery, Description: "Fetch user name"}, - }, appended) - - gotQuery, ok, err := backend.GetOp(ctx, "session-1", "getUser") - require.NoError(t, err) - assert.Equal(t, true, ok) - assert.Equal(t, SessionOp{Name: "getUser", Body: "query GetUser { 
user { id } }", Kind: OperationKindQuery, Description: "Fetch a user"}, gotQuery) + assert.Equal(t, ops, appended) - gotMutation, ok, err := backend.GetOp(ctx, "session-1", "op_delete") + gotQuery, ok, err := backend.GetOp(ctx, "session-1", querySHA) require.NoError(t, err) assert.Equal(t, true, ok) - assert.Equal(t, SessionOp{Name: "op_delete", Body: "mutation DeleteUser { deleteUser(id: 1) }", Kind: OperationKindMutation, Description: "Delete a user"}, gotMutation) + assert.Equal(t, ops[0], gotQuery) - gotCollision, ok, err := backend.GetOp(ctx, "session-1", "getUser_2") + gotMutation, ok, err := backend.GetOp(ctx, "session-1", mutationSHA) require.NoError(t, err) assert.Equal(t, true, ok) - assert.Equal(t, SessionOp{Name: "getUser_2", Body: "query GetUserAgain { user { name } }", Kind: OperationKindQuery, Description: "Fetch user name"}, gotCollision) + assert.Equal(t, ops[1], gotMutation) bundle, err := backend.Bundle(ctx, "session-1") require.NoError(t, err) - assert.Equal(t, "getUser\nop_delete\ngetUser_2", bundle) + assert.Equal(t, querySHA+"\n"+mutationSHA, bundle) require.NoError(t, backend.Reset(ctx, "session-1")) - gotAfterReset, ok, err := backend.GetOp(ctx, "session-1", "getUser") + gotAfterReset, ok, err := backend.GetOp(ctx, "session-1", querySHA) require.NoError(t, err) assert.Equal(t, false, ok) assert.Equal(t, SessionOp{}, gotAfterReset) @@ -112,13 +106,114 @@ func TestMemoryBackendAppendGetOpBundleResetRoundTrip(t *testing.T) { assert.Equal(t, "", bundleAfterReset) } +func TestMemoryBackendAppendIdempotentOnIdenticalBody(t *testing.T) { + ctx := context.Background() + clock := newTestClock() + backend := newTestBackend(t, clock, nil) + + body := "query GetUser { user { id } }" + sha := ShortSHA(body) + + first, err := backend.Append(ctx, "s1", []SessionOp{ + {Name: sha, Body: body, Kind: OperationKindQuery, Description: "v1"}, + }) + require.NoError(t, err) + assert.Equal(t, []SessionOp{ + {Name: sha, Body: body, Kind: OperationKindQuery, 
Description: "v1"}, + }, first) + + // Whitespace-only differences canonicalize to the same SHA, so the + // backend reuses the first registration. + second, err := backend.Append(ctx, "s1", []SessionOp{ + {Name: sha, Body: " query GetUser {\n user { id }\n}\n", Kind: OperationKindQuery, Description: "v2"}, + }) + require.NoError(t, err) + assert.Equal(t, []SessionOp{ + {Name: sha, Body: body, Kind: OperationKindQuery, Description: "v1"}, + }, second) + + names, err := backend.ListNames(ctx, "s1") + require.NoError(t, err) + assert.Equal(t, []string{sha}, names) + + got, ok, err := backend.GetOp(ctx, "s1", sha) + require.NoError(t, err) + assert.Equal(t, true, ok) + assert.Equal(t, SessionOp{Name: sha, Body: body, Kind: OperationKindQuery, Description: "v1"}, got) +} + +func TestMemoryBackendAppendDedupsBodyAcrossPromptNames(t *testing.T) { + // Regression: yoko sometimes returns the same body under different + // document names ("getUser" via one prompt, "fetchUser" via another). + // Storage dedups by canonical body, so the second registration reuses + // the first SessionOp regardless of the inbound DocumentName. 
+ ctx := context.Background() + clock := newTestClock() + backend := newTestBackend(t, clock, nil) + + body := "query GetUser { user { id } }" + sha := ShortSHA(body) + + _, err := backend.Append(ctx, "s1", []SessionOp{ + {Name: sha, Body: body, Kind: OperationKindQuery, DocumentName: "GetUser"}, + }) + require.NoError(t, err) + + second, err := backend.Append(ctx, "s1", []SessionOp{ + {Name: sha, Body: body, Kind: OperationKindQuery, DocumentName: "FetchUser"}, + }) + require.NoError(t, err) + assert.Equal(t, []SessionOp{ + {Name: sha, Body: body, Kind: OperationKindQuery, DocumentName: "GetUser"}, + }, second) + + names, err := backend.ListNames(ctx, "s1") + require.NoError(t, err) + assert.Equal(t, []string{sha}, names) +} + +func TestMemoryBackendAppendDifferentBodiesGetSeparateEntries(t *testing.T) { + // Regression: yoko regenerates an operation under the same document + // name but with a different body. With SHA-based identity each body + // gets its own entry, eliminating the silent overwrite that name-based + // identity used to mask. 
+ ctx := context.Background() + clock := newTestClock() + backend := newTestBackend(t, clock, nil) + + bodyV1 := "query GetUser { user { id } }" + bodyV2 := "query GetUser { user { name } }" + shaV1 := ShortSHA(bodyV1) + shaV2 := ShortSHA(bodyV2) + require.NotEqual(t, shaV1, shaV2) + + _, err := backend.Append(ctx, "s1", []SessionOp{ + {Name: shaV1, Body: bodyV1, Kind: OperationKindQuery, DocumentName: "GetUser"}, + }) + require.NoError(t, err) + + resolved, err := backend.Append(ctx, "s1", []SessionOp{ + {Name: shaV2, Body: bodyV2, Kind: OperationKindQuery, DocumentName: "GetUser"}, + }) + require.NoError(t, err) + assert.Equal(t, []SessionOp{ + {Name: shaV2, Body: bodyV2, Kind: OperationKindQuery, DocumentName: "GetUser"}, + }, resolved) + + names, err := backend.ListNames(ctx, "s1") + require.NoError(t, err) + assert.ElementsMatch(t, []string{shaV1, shaV2}, names) +} + func TestMemoryBackendSetSchemaClearsSessionsAndIncrementsSchemaVersion(t *testing.T) { ctx := context.Background() clock := newTestClock() backend := newTestBackend(t, clock, nil) initialVersion := backend.SchemaVersion() - _, err := backend.Append(ctx, "session-1", []SessionOp{{Name: "get-user", Body: "query { user { id } }", Kind: OperationKindQuery}}) + body := "query { user { id } }" + sha := ShortSHA(body) + _, err := backend.Append(ctx, "session-1", []SessionOp{{Name: sha, Body: body, Kind: OperationKindQuery}}) require.NoError(t, err) schema := &ast.Document{} @@ -127,7 +222,7 @@ func TestMemoryBackendSetSchemaClearsSessionsAndIncrementsSchemaVersion(t *testi assert.Equal(t, initialVersion+1, backend.SchemaVersion()) assert.Equal(t, schema, backend.Schema()) - got, ok, err := backend.GetOp(ctx, "session-1", "getUser") + got, ok, err := backend.GetOp(ctx, "session-1", sha) require.NoError(t, err) assert.Equal(t, false, ok) assert.Equal(t, SessionOp{}, got) @@ -147,23 +242,28 @@ func TestMemoryBackendTTLEvictionUsesInjectedClock(t *testing.T) { Now: clock.Now, }) - _, err := 
backend.Append(ctx, "idle", []SessionOp{{Name: "idle-op", Body: "query { idle }", Kind: OperationKindQuery}}) + idleBody := "query { idle }" + freshBody := "query { fresh }" + idleSHA := ShortSHA(idleBody) + freshSHA := ShortSHA(freshBody) + + _, err := backend.Append(ctx, "idle", []SessionOp{{Name: idleSHA, Body: idleBody, Kind: OperationKindQuery}}) require.NoError(t, err) - _, err = backend.Append(ctx, "fresh", []SessionOp{{Name: "fresh-op", Body: "query { fresh }", Kind: OperationKindQuery}}) + _, err = backend.Append(ctx, "fresh", []SessionOp{{Name: freshSHA, Body: freshBody, Kind: OperationKindQuery}}) require.NoError(t, err) clock.Advance(30 * time.Second) - _, ok, err := backend.GetOp(ctx, "fresh", "freshOp") + _, ok, err := backend.GetOp(ctx, "fresh", freshSHA) require.NoError(t, err) assert.Equal(t, true, ok) clock.Advance(31 * time.Second) backend.sweepIdle() - _, idleOK, err := backend.GetOp(ctx, "idle", "idleOp") + _, idleOK, err := backend.GetOp(ctx, "idle", idleSHA) require.NoError(t, err) assert.Equal(t, false, idleOK) - _, freshOK, err := backend.GetOp(ctx, "fresh", "freshOp") + _, freshOK, err := backend.GetOp(ctx, "fresh", freshSHA) require.NoError(t, err) assert.Equal(t, true, freshOK) } @@ -179,86 +279,78 @@ func TestMemoryBackendLRUEvictionAtMaxSessions(t *testing.T) { Now: clock.Now, }) - _, err := backend.Append(ctx, "session-a", []SessionOp{{Name: "a-op", Body: "query { a }", Kind: OperationKindQuery}}) + aBody := "query { a }" + bBody := "query { b }" + cBody := "query { c }" + aSHA := ShortSHA(aBody) + bSHA := ShortSHA(bBody) + cSHA := ShortSHA(cBody) + + _, err := backend.Append(ctx, "session-a", []SessionOp{{Name: aSHA, Body: aBody, Kind: OperationKindQuery}}) require.NoError(t, err) clock.Advance(time.Second) - _, err = backend.Append(ctx, "session-b", []SessionOp{{Name: "b-op", Body: "query { b }", Kind: OperationKindQuery}}) + _, err = backend.Append(ctx, "session-b", []SessionOp{{Name: bSHA, Body: bBody, Kind: OperationKindQuery}}) 
require.NoError(t, err) clock.Advance(time.Second) - _, ok, err := backend.GetOp(ctx, "session-a", "aOp") + _, ok, err := backend.GetOp(ctx, "session-a", aSHA) require.NoError(t, err) assert.Equal(t, true, ok) clock.Advance(time.Second) - _, err = backend.Append(ctx, "session-c", []SessionOp{{Name: "c-op", Body: "query { c }", Kind: OperationKindQuery}}) + _, err = backend.Append(ctx, "session-c", []SessionOp{{Name: cSHA, Body: cBody, Kind: OperationKindQuery}}) require.NoError(t, err) - _, aOK, err := backend.GetOp(ctx, "session-a", "aOp") + _, aOK, err := backend.GetOp(ctx, "session-a", aSHA) require.NoError(t, err) assert.Equal(t, true, aOK) - _, bOK, err := backend.GetOp(ctx, "session-b", "bOp") + _, bOK, err := backend.GetOp(ctx, "session-b", bSHA) require.NoError(t, err) assert.Equal(t, false, bOK) - _, cOK, err := backend.GetOp(ctx, "session-c", "cOp") + _, cOK, err := backend.GetOp(ctx, "session-c", cSHA) require.NoError(t, err) assert.Equal(t, true, cOK) } -func TestMemoryBackendConcurrentAppendIsRaceFreeAndSuffixesNames(t *testing.T) { +func TestMemoryBackendConcurrentAppendSameBodyConvergesToOne(t *testing.T) { ctx := context.Background() clock := newTestClock() backend := newTestBackend(t, clock, nil) const goroutines = 32 + body := "query Shared { shared }" + sha := ShortSHA(body) + var wg sync.WaitGroup - errs := make(chan error, goroutines) + results := make(chan []SessionOp, goroutines) for i := range goroutines { wg.Add(1) go func(i int) { defer wg.Done() - _, err := backend.Append(ctx, "shared", []SessionOp{{ - Name: "shared-op", - Body: fmt.Sprintf("query Shared%d { shared%d }", i, i), + resolved, err := backend.Append(ctx, "shared", []SessionOp{{ + Name: sha, + Body: body, Kind: OperationKindQuery, Description: fmt.Sprintf("Shared %d", i), }}) - errs <- err + require.NoError(t, err) + results <- resolved }(i) } wg.Wait() - close(errs) + close(results) - for err := range errs { - require.NoError(t, err) + for resolved := range results { + 
require.Equal(t, 1, len(resolved)) + assert.Equal(t, sha, resolved[0].Name) } - names := make([]string, 0, goroutines) - for i := range goroutines { - name := "sharedOp" - if i > 0 { - name = fmt.Sprintf("sharedOp_%d", i+1) - } - op, ok, err := backend.GetOp(ctx, "shared", name) - require.NoError(t, err) - assert.Equal(t, true, ok) - names = append(names, op.Name) - } - - sort.Strings(names) - want := make([]string, 0, goroutines) - for i := range goroutines { - name := "sharedOp" - if i > 0 { - name = fmt.Sprintf("sharedOp_%d", i+1) - } - want = append(want, name) - } - sort.Strings(want) - assert.Equal(t, want, names) + names, err := backend.ListNames(ctx, "shared") + require.NoError(t, err) + assert.Equal(t, []string{sha}, names) } func TestMemoryBackendBundleCacheInvalidatesOnAppend(t *testing.T) { @@ -279,26 +371,31 @@ func TestMemoryBackendBundleCacheInvalidatesOnAppend(t *testing.T) { }) backend := newTestBackend(t, clock, renderer) - _, err := backend.Append(ctx, "session-1", []SessionOp{{Name: "one", Body: "query { one }", Kind: OperationKindQuery}}) + oneBody := "query { one }" + twoBody := "query { two }" + oneSHA := ShortSHA(oneBody) + twoSHA := ShortSHA(twoBody) + + _, err := backend.Append(ctx, "session-1", []SessionOp{{Name: oneSHA, Body: oneBody, Kind: OperationKindQuery}}) require.NoError(t, err) first, err := backend.Bundle(ctx, "session-1") require.NoError(t, err) - assert.Equal(t, "one", first) + assert.Equal(t, oneSHA, first) second, err := backend.Bundle(ctx, "session-1") require.NoError(t, err) - assert.Equal(t, "one", second) + assert.Equal(t, oneSHA, second) - _, err = backend.Append(ctx, "session-1", []SessionOp{{Name: "two", Body: "query { two }", Kind: OperationKindQuery}}) + _, err = backend.Append(ctx, "session-1", []SessionOp{{Name: twoSHA, Body: twoBody, Kind: OperationKindQuery}}) require.NoError(t, err) third, err := backend.Bundle(ctx, "session-1") require.NoError(t, err) - assert.Equal(t, "one,two", third) + assert.Equal(t, 
oneSHA+","+twoSHA, third) mu.Lock() gotRendered := append([]string(nil), rendered...) mu.Unlock() - assert.Equal(t, []string{"one", "one,two"}, gotRendered) + assert.Equal(t, []string{oneSHA, oneSHA + "," + twoSHA}, gotRendered) } func TestMemoryBackendBundleDropsWholeOpsAtMaxBundleBytes(t *testing.T) { @@ -311,22 +408,31 @@ func TestMemoryBackendBundleDropsWholeOpsAtMaxBundleBytes(t *testing.T) { } return strings.Join(names, "|"), nil }) + + oneBody := "query { one }" + twoBody := "query { two }" + threeBody := "query { three }" + oneSHA := ShortSHA(oneBody) + twoSHA := ShortSHA(twoBody) + threeSHA := ShortSHA(threeBody) + twoOpsBundle := oneSHA + "|" + twoSHA + backend := NewMemoryBackend(MemoryConfig{ SessionTTL: time.Hour, MaxSessions: 100, - MaxBundleBytes: len("one|two"), + MaxBundleBytes: len(twoOpsBundle), Renderer: renderer, Now: clock.Now, }) _, err := backend.Append(ctx, "session-1", []SessionOp{ - {Name: "one", Body: "query { one }", Kind: OperationKindQuery}, - {Name: "two", Body: "query { two }", Kind: OperationKindQuery}, - {Name: "three", Body: "query { three }", Kind: OperationKindQuery}, + {Name: oneSHA, Body: oneBody, Kind: OperationKindQuery}, + {Name: twoSHA, Body: twoBody, Kind: OperationKindQuery}, + {Name: threeSHA, Body: threeBody, Kind: OperationKindQuery}, }) require.NoError(t, err) bundle, err := backend.Bundle(ctx, "session-1") require.NoError(t, err) - assert.Equal(t, "one|two", bundle) + assert.Equal(t, twoOpsBundle, bundle) } diff --git a/router/internal/codemode/storage/naming.go b/router/internal/codemode/storage/naming.go index a3b91eadc9..6c202b6cb4 100644 --- a/router/internal/codemode/storage/naming.go +++ b/router/internal/codemode/storage/naming.go @@ -1,191 +1,35 @@ package storage import ( - "slices" - "strconv" + "crypto/sha256" + "encoding/hex" "strings" - "unicode" ) -var reservedWords = map[string]struct{}{ - "abstract": {}, - "any": {}, - "as": {}, - "async": {}, - "await": {}, - "boolean": {}, - "break": {}, - "case": 
{}, - "catch": {}, - "class": {}, - "const": {}, - "constructor": {}, - "continue": {}, - "debugger": {}, - "declare": {}, - "default": {}, - "delete": {}, - "do": {}, - "else": {}, - "enum": {}, - "export": {}, - "extends": {}, - "false": {}, - "finally": {}, - "for": {}, - "from": {}, - "function": {}, - "get": {}, - "if": {}, - "implements": {}, - "import": {}, - "in": {}, - "infer": {}, - "instanceof": {}, - "interface": {}, - "is": {}, - "keyof": {}, - "let": {}, - "module": {}, - "namespace": {}, - "never": {}, - "new": {}, - "null": {}, - "number": {}, - "object": {}, - "of": {}, - "package": {}, - "private": {}, - "protected": {}, - "public": {}, - "readonly": {}, - "require": {}, - "return": {}, - "satisfies": {}, - "set": {}, - "static": {}, - "string": {}, - "super": {}, - "switch": {}, - "symbol": {}, - "this": {}, - "throw": {}, - "true": {}, - "try": {}, - "type": {}, - "typeof": {}, - "undefined": {}, - "unique": {}, - "unknown": {}, - "var": {}, - "void": {}, - "while": {}, - "with": {}, - "yield": {}, -} - -func NormalizeName(raw string) string { - // Idempotency: names produced by an earlier NormalizeName call (carrying our reserved-word - // or leading-digit prefixes) round-trip without re-splitting. 
- if rest, ok := strings.CutPrefix(raw, "op_"); ok { - if _, reserved := reservedWords[rest]; reserved && isLowerCamel(rest) { - return raw - } - } - if rest, ok := strings.CutPrefix(raw, "_"); ok { - if len(rest) > 0 && unicode.IsDigit(rune(rest[0])) && isIdentTail(rest) { - return raw - } - } - words := strings.FieldsFunc(raw, func(r rune) bool { - return !unicode.IsLetter(r) && !unicode.IsDigit(r) - }) - words = slices.DeleteFunc(words, func(word string) bool { - return word == "" - }) - if len(words) == 0 { - return "operation" - } - - var builder strings.Builder - for i, word := range words { - if i == 0 { - builder.WriteString(lowerFirst(word)) - continue - } - builder.WriteString(upperFirst(word)) - } - - name := builder.String() - if name == "" { - name = "operation" - } - if first, _ := firstRune(name); unicode.IsDigit(first) { - name = "_" + name - } - if _, ok := reservedWords[name]; ok { - name = "op_" + name - } - return name -} - -func SuffixedName(base string, taken map[string]struct{}) string { - if _, ok := taken[base]; !ok { - return base - } - for i := 2; ; i++ { - name := base + "_" + strconv.Itoa(i) - if _, ok := taken[name]; !ok { - return name - } - } -} - -func lowerFirst(value string) string { - if value == "" { - return value - } - runes := []rune(value) - runes[0] = unicode.ToLower(runes[0]) - return string(runes) -} - -func upperFirst(value string) string { - if value == "" { - return value - } - runes := []rune(strings.ToLower(value)) - runes[0] = unicode.ToUpper(runes[0]) - return string(runes) -} - -func isLowerCamel(value string) bool { - if value == "" { - return false - } - for i, r := range value { - if i == 0 && !unicode.IsLower(r) { - return false - } - if !unicode.IsLetter(r) && !unicode.IsDigit(r) { - return false - } - } - return true -} - -func isIdentTail(value string) bool { - for _, r := range value { - if !unicode.IsLetter(r) && !unicode.IsDigit(r) { - return false - } - } - return true -} - -func firstRune(value string) 
(rune, bool) { - for _, r := range value { - return r, true - } - return 0, false +// shortSHALen is the number of hex characters from the SHA-256 prefix used to +// derive a stable per-operation identifier. With 8 hex chars (32 bits) the +// birthday-collision probability across a 1k-op session is ~0.012%, which is +// fine for the session-scoped use we have. +const shortSHALen = 8 + +// ShortSHA returns a stable identifier derived from the operation body. The +// result is a valid JavaScript identifier ("o" prefix + lowercase hex), so +// the model can call `tools.(...)` directly without bracket access. +// +// The hash is computed over CanonicalBody(body) so identical operations that +// differ only in whitespace map to the same identifier. This is the key +// invariant: operation identity = operation content. Two operations that +// happen to share a name from yoko but have different bodies get different +// identifiers; two prompts that produce the same body share an identifier. +func ShortSHA(body string) string { + sum := sha256.Sum256([]byte(CanonicalBody(body))) + return "o" + hex.EncodeToString(sum[:])[:shortSHALen] +} + +// CanonicalBody returns a whitespace-normalized form of a GraphQL operation +// body for equality comparison. Two bodies that differ only in formatting +// (newlines, indentation, repeated spaces) compare equal. It does NOT +// canonicalize alias names, argument order, or fragment expansion. 
+func CanonicalBody(body string) string { + return strings.Join(strings.Fields(body), " ") } diff --git a/router/internal/codemode/storage/naming_test.go b/router/internal/codemode/storage/naming_test.go index 10215f9730..d9e90e2dd8 100644 --- a/router/internal/codemode/storage/naming_test.go +++ b/router/internal/codemode/storage/naming_test.go @@ -6,79 +6,49 @@ import ( "github.com/stretchr/testify/assert" ) -func TestNormalizeName(t *testing.T) { - tests := []struct { - name string - raw string - want string - }{ - {name: "kebab case", raw: "get-user-by-id", want: "getUserById"}, - {name: "snake case", raw: "get_user_by_id", want: "getUserById"}, - {name: "space separated", raw: "Get User By ID", want: "getUserById"}, - {name: "mixed separators", raw: "get__user--by id", want: "getUserById"}, - {name: "already camel", raw: "getUserById", want: "getUserById"}, - {name: "leading digit", raw: "123foo", want: "_123foo"}, - {name: "leading digit with separators", raw: "123-foo-bar", want: "_123FooBar"}, - {name: "reserved word", raw: "delete", want: "op_delete"}, - {name: "reserved word after normalization", raw: "class", want: "op_class"}, - {name: "invalid punctuation", raw: "get$user#by%id", want: "getUserById"}, - {name: "empty input", raw: "", want: "operation"}, - {name: "only invalid input", raw: "$$$", want: "operation"}, - {name: "underscore output for reserved word is not rechecked", raw: "op-delete", want: "opDelete"}, - } +func TestShortSHA(t *testing.T) { + t.Run("identifier shape", func(t *testing.T) { + got := ShortSHA("query GetUser { user { id } }") + assert.Equal(t, "oe4467893", got) + }) - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - assert.Equal(t, tt.want, NormalizeName(tt.raw)) - }) - } + t.Run("whitespace-equivalent bodies share an identifier", func(t *testing.T) { + a := ShortSHA("query GetUser { user { id } }") + b := ShortSHA(" query GetUser {\n user { id }\n}\n") + assert.Equal(t, a, b) + }) + + t.Run("different bodies 
produce different identifiers", func(t *testing.T) { + a := ShortSHA("query GetUser { user { id } }") + b := ShortSHA("query GetUser { user { name } }") + assert.NotEqual(t, a, b) + }) + + t.Run("same body via different prompt name still maps to same identifier", func(t *testing.T) { + // Regression: yoko returns the same body under "fetchUser" in one + // search and "getUser" in another. The identifier must be the + // content-derived SHA, not the document name. + a := ShortSHA("query GetUser { user { id } }") + b := ShortSHA("query GetUser { user { id } }") + assert.Equal(t, a, b) + }) } -func TestSuffixedName(t *testing.T) { +func TestCanonicalBody(t *testing.T) { tests := []struct { - name string - base string - taken map[string]struct{} - want string + name string + raw string + want string }{ - { - name: "first use keeps base", - base: "getUser", - taken: map[string]struct{}{}, - want: "getUser", - }, - { - name: "first collision uses suffix two", - base: "getUser", - taken: map[string]struct{}{ - "getUser": {}, - }, - want: "getUser_2", - }, - { - name: "skips occupied suffixes", - base: "getUser", - taken: map[string]struct{}{ - "getUser": {}, - "getUser_2": {}, - "getUser_3": {}, - }, - want: "getUser_4", - }, - { - name: "gap is reused", - base: "getUser", - taken: map[string]struct{}{ - "getUser": {}, - "getUser_3": {}, - }, - want: "getUser_2", - }, + {name: "single-line passthrough", raw: "query GetUser { user { id } }", want: "query GetUser { user { id } }"}, + {name: "multi-line collapses", raw: " query GetUser {\n user { id }\n}\n", want: "query GetUser { user { id } }"}, + {name: "tabs normalize", raw: "query\tGetUser\t{ user { id } }", want: "query GetUser { user { id } }"}, + {name: "empty input stays empty", raw: "", want: ""}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assert.Equal(t, tt.want, SuffixedName(tt.base, tt.taken)) + assert.Equal(t, tt.want, CanonicalBody(tt.raw)) }) } } diff --git 
a/router/internal/codemode/storage/redis_backend.go b/router/internal/codemode/storage/redis_backend.go index 90e6e66883..b8e5f42a1e 100644 --- a/router/internal/codemode/storage/redis_backend.go +++ b/router/internal/codemode/storage/redis_backend.go @@ -78,6 +78,11 @@ func NewRedisBackend(cfg RedisConfig) (*RedisBackend, error) { }, nil } +// Append resolves each input op against the session and returns one +// SessionOp per input — either the freshly-registered op or the +// pre-existing op it was deduped against. See MemoryBackend.Append for +// the resolution rule (CanonicalBody match); this implementation is +// identical apart from running inside a Redis WATCH transaction. func (b *RedisBackend) Append(ctx context.Context, sessionID string, ops []SessionOp) ([]SessionOp, error) { if err := ctx.Err(); err != nil { return nil, err @@ -87,7 +92,7 @@ func (b *RedisBackend) Append(ctx context.Context, sessionID string, ops []Sessi } backoff := 5 * time.Millisecond - var appended []SessionOp + var resolved []SessionOp for { if err := ctx.Err(); err != nil { return nil, err @@ -102,19 +107,31 @@ func (b *RedisBackend) Append(ctx context.Context, sessionID string, ops []Sessi return err } - taken := make(map[string]struct{}, len(entries)+len(ops)) + byBody := make(map[string]SessionOp, len(entries)+len(ops)) for _, entry := range entries { - taken[entry.Name] = struct{}{} + byBody[CanonicalBody(entry.Body)] = entry.SessionOp } - appended = make([]SessionOp, 0, len(ops)) + + currentResolved := make([]SessionOp, 0, len(ops)) + appendedAny := false for _, op := range ops { - op.Name = SuffixedName(NormalizeName(op.Name), taken) - taken[op.Name] = struct{}{} + canonical := CanonicalBody(op.Body) + if existing, ok := byBody[canonical]; ok { + currentResolved = append(currentResolved, existing) + continue + } entries = append(entries, redisOpEntry{ SessionOp: op, LastUsed: now, }) - appended = append(appended, op) + byBody[canonical] = op + currentResolved = 
append(currentResolved, op) + appendedAny = true + } + + resolved = currentResolved + if !appendedAny { + return nil } payload, err := json.Marshal(entries) if err != nil { @@ -130,7 +147,7 @@ func (b *RedisBackend) Append(ctx context.Context, sessionID string, ops []Sessi return err }, opsKey) if err == nil { - return appended, nil + return resolved, nil } b.logger.Debug("retrying code mode redis append", diff --git a/router/internal/codemode/storage/redis_backend_test.go b/router/internal/codemode/storage/redis_backend_test.go index 3bb736353c..6f33fe8836 100644 --- a/router/internal/codemode/storage/redis_backend_test.go +++ b/router/internal/codemode/storage/redis_backend_test.go @@ -67,28 +67,28 @@ func TestRedisBackendAppendGetOpRoundTrip(t *testing.T) { ctx := context.Background() backend, _, _ := newTestRedisBackend(t, nil, time.Hour) + queryBody := "query GetUser { user { id } }" + mutationBody := "mutation DeleteUser { deleteUser(id: 1) }" + querySHA := ShortSHA(queryBody) + mutationSHA := ShortSHA(mutationBody) + ops := []SessionOp{ - {Name: "get-user", Body: "query GetUser { user { id } }", Kind: OperationKindQuery, Description: "Fetch a user"}, - {Name: "op_delete", Body: "mutation DeleteUser { deleteUser(id: 1) }", Kind: OperationKindMutation, Description: "Delete a user"}, - {Name: "get-user", Body: "query GetUserAgain { user { name } }", Kind: OperationKindQuery, Description: "Fetch user name"}, + {Name: querySHA, Body: queryBody, Kind: OperationKindQuery, Description: "Fetch a user"}, + {Name: mutationSHA, Body: mutationBody, Kind: OperationKindMutation, Description: "Delete a user"}, } appended, err := backend.Append(ctx, "session-1", ops) require.NoError(t, err) - assert.Equal(t, []SessionOp{ - {Name: "getUser", Body: "query GetUser { user { id } }", Kind: OperationKindQuery, Description: "Fetch a user"}, - {Name: "op_delete", Body: "mutation DeleteUser { deleteUser(id: 1) }", Kind: OperationKindMutation, Description: "Delete a user"}, - {Name: 
"getUser_2", Body: "query GetUserAgain { user { name } }", Kind: OperationKindQuery, Description: "Fetch user name"}, - }, appended) + assert.Equal(t, ops, appended) - gotQuery, ok, err := backend.GetOp(ctx, "session-1", "getUser") + gotQuery, ok, err := backend.GetOp(ctx, "session-1", querySHA) require.NoError(t, err) assert.Equal(t, true, ok) - assert.Equal(t, SessionOp{Name: "getUser", Body: "query GetUser { user { id } }", Kind: OperationKindQuery, Description: "Fetch a user"}, gotQuery) + assert.Equal(t, ops[0], gotQuery) - gotCollision, ok, err := backend.GetOp(ctx, "session-1", "getUser_2") + gotMutation, ok, err := backend.GetOp(ctx, "session-1", mutationSHA) require.NoError(t, err) assert.Equal(t, true, ok) - assert.Equal(t, SessionOp{Name: "getUser_2", Body: "query GetUserAgain { user { name } }", Kind: OperationKindQuery, Description: "Fetch user name"}, gotCollision) + assert.Equal(t, ops[1], gotMutation) gotMissing, ok, err := backend.GetOp(ctx, "session-1", "missing") require.NoError(t, err) @@ -96,6 +96,89 @@ func TestRedisBackendAppendGetOpRoundTrip(t *testing.T) { assert.Equal(t, SessionOp{}, gotMissing) } +func TestRedisBackendAppendIdempotentOnIdenticalBody(t *testing.T) { + ctx := context.Background() + backend, _, _ := newTestRedisBackend(t, nil, time.Hour) + + body := "query GetUser { user { id } }" + sha := ShortSHA(body) + + first, err := backend.Append(ctx, "s1", []SessionOp{ + {Name: sha, Body: body, Kind: OperationKindQuery, Description: "v1"}, + }) + require.NoError(t, err) + assert.Equal(t, []SessionOp{ + {Name: sha, Body: body, Kind: OperationKindQuery, Description: "v1"}, + }, first) + + // Whitespace-only differences canonicalize to the same SHA, so the + // backend reuses the first registration. 
+ second, err := backend.Append(ctx, "s1", []SessionOp{ + {Name: sha, Body: " query GetUser {\n user { id }\n}\n", Kind: OperationKindQuery, Description: "v2"}, + }) + require.NoError(t, err) + assert.Equal(t, []SessionOp{ + {Name: sha, Body: body, Kind: OperationKindQuery, Description: "v1"}, + }, second) + + names, err := backend.ListNames(ctx, "s1") + require.NoError(t, err) + assert.Equal(t, []string{sha}, names) +} + +func TestRedisBackendAppendDedupsBodyAcrossPromptNames(t *testing.T) { + ctx := context.Background() + backend, _, _ := newTestRedisBackend(t, nil, time.Hour) + + body := "query GetUser { user { id } }" + sha := ShortSHA(body) + + _, err := backend.Append(ctx, "s1", []SessionOp{ + {Name: sha, Body: body, Kind: OperationKindQuery, DocumentName: "GetUser"}, + }) + require.NoError(t, err) + + second, err := backend.Append(ctx, "s1", []SessionOp{ + {Name: sha, Body: body, Kind: OperationKindQuery, DocumentName: "FetchUser"}, + }) + require.NoError(t, err) + assert.Equal(t, []SessionOp{ + {Name: sha, Body: body, Kind: OperationKindQuery, DocumentName: "GetUser"}, + }, second) + + names, err := backend.ListNames(ctx, "s1") + require.NoError(t, err) + assert.Equal(t, []string{sha}, names) +} + +func TestRedisBackendAppendDifferentBodiesGetSeparateEntries(t *testing.T) { + ctx := context.Background() + backend, _, _ := newTestRedisBackend(t, nil, time.Hour) + + bodyV1 := "query GetUser { user { id } }" + bodyV2 := "query GetUser { user { name } }" + shaV1 := ShortSHA(bodyV1) + shaV2 := ShortSHA(bodyV2) + require.NotEqual(t, shaV1, shaV2) + + _, err := backend.Append(ctx, "s1", []SessionOp{ + {Name: shaV1, Body: bodyV1, Kind: OperationKindQuery, DocumentName: "GetUser"}, + }) + require.NoError(t, err) + + resolved, err := backend.Append(ctx, "s1", []SessionOp{ + {Name: shaV2, Body: bodyV2, Kind: OperationKindQuery, DocumentName: "GetUser"}, + }) + require.NoError(t, err) + assert.Equal(t, []SessionOp{ + {Name: shaV2, Body: bodyV2, Kind: 
OperationKindQuery, DocumentName: "GetUser"}, + }, resolved) + + names, err := backend.ListNames(ctx, "s1") + require.NoError(t, err) + assert.ElementsMatch(t, []string{shaV1, shaV2}, names) +} + func TestRedisBackendBundleRendersAndReadsFromCache(t *testing.T) { ctx := context.Background() var renders atomic.Int64 @@ -104,17 +187,19 @@ func TestRedisBackendBundleRendersAndReadsFromCache(t *testing.T) { return fmt.Sprintf("render-%d:%s", renders.Load(), ops[0].Name), nil }), time.Hour) - _, err := backend.Append(ctx, "session-1", []SessionOp{{Name: "getUser", Body: "query { user { id } }", Kind: OperationKindQuery}}) + body := "query { user { id } }" + sha := ShortSHA(body) + _, err := backend.Append(ctx, "session-1", []SessionOp{{Name: sha, Body: body, Kind: OperationKindQuery}}) require.NoError(t, err) first, err := backend.Bundle(ctx, "session-1") require.NoError(t, err) - assert.Equal(t, "render-1:getUser", first) + assert.Equal(t, "render-1:"+sha, first) assert.Equal(t, true, mr.Exists(backend.bundleKey("session-1"))) second, err := backend.Bundle(ctx, "session-1") require.NoError(t, err) - assert.Equal(t, "render-1:getUser", second) + assert.Equal(t, "render-1:"+sha, second) assert.Equal(t, int64(1), renders.Load()) } @@ -122,7 +207,9 @@ func TestRedisBackendResetClearsOpsAndBundleKeys(t *testing.T) { ctx := context.Background() backend, mr, _ := newTestRedisBackend(t, nil, time.Hour) - _, err := backend.Append(ctx, "session-1", []SessionOp{{Name: "getUser", Body: "query { user { id } }", Kind: OperationKindQuery}}) + body := "query { user { id } }" + sha := ShortSHA(body) + _, err := backend.Append(ctx, "session-1", []SessionOp{{Name: sha, Body: body, Kind: OperationKindQuery}}) require.NoError(t, err) _, err = backend.Bundle(ctx, "session-1") require.NoError(t, err) @@ -146,7 +233,9 @@ func TestRedisBackendSetSchemaRotatesKeysAndKeepsOldKeysUntilTTL(t *testing.T) { }), time.Hour) backend.SetSchema(schemaA) - _, err := backend.Append(ctx, "session-1", 
[]SessionOp{{Name: "getUser", Body: "query { user { id } }", Kind: OperationKindQuery}}) + body := "query { user { id } }" + sha := ShortSHA(body) + _, err := backend.Append(ctx, "session-1", []SessionOp{{Name: sha, Body: body, Kind: OperationKindQuery}}) require.NoError(t, err) oldOpsKey := backend.opsKey("session-1") oldBundleKey := backend.bundleKey("session-1") @@ -164,7 +253,7 @@ func TestRedisBackendSetSchemaRotatesKeysAndKeepsOldKeysUntilTTL(t *testing.T) { assert.Equal(t, true, mr.Exists(oldOpsKey)) assert.Equal(t, true, mr.Exists(oldBundleKey)) - _, err = backend.Append(ctx, "session-1", []SessionOp{{Name: "getUser", Body: "query { user { id } }", Kind: OperationKindQuery}}) + _, err = backend.Append(ctx, "session-1", []SessionOp{{Name: sha, Body: body, Kind: OperationKindQuery}}) require.NoError(t, err) second, err := backend.Bundle(ctx, "session-1") require.NoError(t, err) @@ -181,13 +270,18 @@ func TestRedisBackendConcurrentAppendRetriesWatchConflicts(t *testing.T) { var wg sync.WaitGroup errs := make(chan error, goroutines) - for i := 0; i < goroutines; i++ { + for i := range goroutines { wg.Add(1) go func(worker int) { defer wg.Done() ops := make([]SessionOp, 0, opsPerGoroutine) - for j := 0; j < opsPerGoroutine; j++ { - ops = append(ops, SessionOp{Name: fmt.Sprintf("op_%02d_%02d", worker, j), Body: "query { ok }", Kind: OperationKindQuery}) + for j := range opsPerGoroutine { + body := fmt.Sprintf("query Q_%02d_%02d { f_%02d_%02d }", worker, j, worker, j) + ops = append(ops, SessionOp{ + Name: ShortSHA(body), + Body: body, + Kind: OperationKindQuery, + }) } _, err := backend.Append(ctx, "session-1", ops) errs <- err @@ -212,7 +306,8 @@ func TestRedisBackendAppendAbandonsOnContextDone(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 25*time.Millisecond) defer cancel() - _, err := backend.Append(ctx, "session-1", []SessionOp{{Name: "getUser", Body: "query { user { id } }", Kind: OperationKindQuery}}) + body := "query { user { id 
} }" + _, err := backend.Append(ctx, "session-1", []SessionOp{{Name: ShortSHA(body), Body: body, Kind: OperationKindQuery}}) require.Error(t, err) assert.Equal(t, true, errors.Is(err, context.DeadlineExceeded)) @@ -222,7 +317,8 @@ func TestRedisBackendExpiresKeysOnWrites(t *testing.T) { ctx := context.Background() backend, mr, _ := newTestRedisBackend(t, nil, 10*time.Second) - _, err := backend.Append(ctx, "session-1", []SessionOp{{Name: "getUser", Body: "query { user { id } }", Kind: OperationKindQuery}}) + body := "query { user { id } }" + _, err := backend.Append(ctx, "session-1", []SessionOp{{Name: ShortSHA(body), Body: body, Kind: OperationKindQuery}}) require.NoError(t, err) opsKey := backend.opsKey("session-1") assert.Equal(t, 10*time.Second, mr.TTL(opsKey)) @@ -242,7 +338,9 @@ func TestRedisBackendBundleWriteBackIsBestEffort(t *testing.T) { backend, mr, _ := newTestRedisBackend(t, testRedisRenderer(func(_ context.Context, ops []SessionOp, _ *ast.Document) (string, error) { return "rendered:" + ops[0].Name, nil }), time.Hour) - _, err := backend.Append(ctx, "session-1", []SessionOp{{Name: "getUser", Body: "query { user { id } }", Kind: OperationKindQuery}}) + body := "query { user { id } }" + sha := ShortSHA(body) + _, err := backend.Append(ctx, "session-1", []SessionOp{{Name: sha, Body: body, Kind: OperationKindQuery}}) require.NoError(t, err) mr.Server().SetPreHook(func(c *miniredisserver.Peer, cmd string, _ ...string) bool { @@ -259,6 +357,6 @@ func TestRedisBackendBundleWriteBackIsBestEffort(t *testing.T) { bundle, err := backend.Bundle(ctx, "session-1") require.NoError(t, err) - assert.Equal(t, "rendered:getUser", bundle) + assert.Equal(t, "rendered:"+sha, bundle) assert.Equal(t, false, mr.Exists(backend.bundleKey("session-1"))) } diff --git a/router/internal/codemode/storage/types.go b/router/internal/codemode/storage/types.go index ba3f1c7df2..6fa742eff9 100644 --- a/router/internal/codemode/storage/types.go +++ 
b/router/internal/codemode/storage/types.go @@ -8,8 +8,22 @@ const ( ) type SessionOp struct { - Name string - Body string - Kind OperationKind - Description string + // Name is the JS-side identifier exposed to user code as + // `tools.`. It is the ShortSHA() projection of the canonical + // body — content-derived, so two operations with the same body always + // share an identifier and two operations that yoko hands back under + // the same document name but with different bodies do not collide. + Name string + // Body is the GraphQL operation source text — exactly one named + // operation per the yoko proto contract. + Body string + Kind OperationKind + // DocumentName is the operation's name as it appears INSIDE Body + // (yoko's `operation_name` field). The host bridge passes this — not + // Name — as `operationName` when invoking the operation against + // /graphql, because the router's parser matches the document's + // literal operation name. Falls back to Name when empty (older + // sessions, tests that omit the field). + DocumentName string + Description string } From 3c5d33eadd78e44da7ab7c4a8d40aadec54cadac Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Fri, 8 May 2026 09:58:58 +0200 Subject: [PATCH 08/10] chore(code-mode-v2): demo full-federation fix + yoko proto/mock refresh Multi-topic snapshot of code-mode-v2 work in progress. demo/code-mode: - Include all non-EDFS demo subgraphs in the local federation: hobbies, products, test1, countries + the products_fg feature graph (under feature flag myff), in addition to the existing employees, family, availability, mood. Mirrors demo/graph-no-edg.yaml. The previous 4-subgraph set silently lacked Employee.hobbies (and Country, products, test1 fields), so code-mode-search-tools could never resolve prompts about those domains. - Fix BSD sed bug in prepare-schemas.sh: the previous regex used `\b` which macOS BSD sed treats as a literal `b`, so @authenticated stayed on type/enum/interface definitions. 
Switch to a portable (^|whitespace)/(non-word|EOL) anchor pair. Also extend the loop to cover all 9 subgraphs. - start.sh + run_subgraphs_subset.sh: launch the 5 added subgraphs on ports 4003/4004/4006/4009/4010 with matching wait_url checks. - README + Makefile comment: document the new subgraph set. yoko proto + mock + client: - Rewrite proto/wg/cosmo/code_mode/yoko/v1/yoko.proto and regenerate router/gen/proto/.../{yoko.pb.go, yokov1connect/yoko.connect.go}. - Rewrite demo/code-mode/yoko-mock/main.go (and tests). - Refresh router/internal/codemode/yoko/{client.go, client_test.go, searcher.go} against the new contract. - Bump demo/code-mode/{mcp-stdio-proxy,yoko-mock}/go.{mod,sum}. Connect demo: - Refresh demo/code-mode-connect/{README.md, router-config.yaml, start.sh} against the same yoko contract. Build / generate: - Makefile: pass --include-imports to `buf generate` so transitive protobuf imports are regenerated alongside the yoko v1 schema. - buf.yaml + buf.lock: pin the buf module deps now needed for that --include-imports run. - router/gen/proto/buf/validate/validate.pb.go: regenerated buf validate import. Misc: - router/pkg/codemode/varschema: new package + tests. - router/pkg/grpcconnector/grpcplugin/grpc_plugin.go: minor update. - router/internal/codemode/sandbox/sandbox_test.go + router/internal/codemode/server/observability_handler_test.go: adjust tests for the refreshed client surface. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- Makefile | 23 +- buf.lock | 6 + buf.yaml | 2 + demo/code-mode-connect/README.md | 12 +- demo/code-mode-connect/router-config.yaml | 8 +- demo/code-mode-connect/start.sh | 43 +- demo/code-mode/README.md | 24 +- demo/code-mode/graph.yaml | 24 + demo/code-mode/mcp-stdio-proxy/go.mod | 9 +- demo/code-mode/mcp-stdio-proxy/go.sum | 12 +- demo/code-mode/prepare-schemas.sh | 11 +- demo/code-mode/router-config.yaml | 2 +- demo/code-mode/run_subgraphs_subset.sh | 7 +- demo/code-mode/start.sh | 40 +- demo/code-mode/yoko-mock/go.mod | 8 +- demo/code-mode/yoko-mock/go.sum | 30 +- demo/code-mode/yoko-mock/main.go | 282 +- demo/code-mode/yoko-mock/main_test.go | 123 +- proto/wg/cosmo/code_mode/yoko/v1/yoko.proto | 125 +- router/gen/proto/buf/validate/validate.pb.go | 9165 +++++++++++++++++ .../wg/cosmo/code_mode/yoko/v1/yoko.pb.go | 447 +- .../yoko/v1/yokov1connect/yoko.connect.go | 108 +- router/go.mod | 24 +- router/go.sum | 53 +- .../internal/codemode/sandbox/sandbox_test.go | 59 + .../server/observability_handler_test.go | 8 +- router/internal/codemode/yoko/client.go | 62 +- router/internal/codemode/yoko/client_test.go | 389 +- router/internal/codemode/yoko/searcher.go | 5 +- router/pkg/codemode/varschema/varschema.go | 329 + .../pkg/codemode/varschema/varschema_test.go | 79 + .../grpcconnector/grpcplugin/grpc_plugin.go | 6 +- 32 files changed, 10726 insertions(+), 799 deletions(-) create mode 100644 buf.lock create mode 100644 router/gen/proto/buf/validate/validate.pb.go create mode 100644 router/pkg/codemode/varschema/varschema.go create mode 100644 router/pkg/codemode/varschema/varschema_test.go diff --git a/Makefile b/Makefile index 12ec52c48c..c9606baf24 100644 --- a/Makefile +++ b/Makefile @@ -116,7 +116,7 @@ generate: make generate-go generate-go: - rm -rf router/gen && buf generate --path proto/wg/cosmo/node --path proto/wg/cosmo/common --path proto/wg/cosmo/graphqlmetrics --path proto/wg/cosmo/code_mode/yoko/v1 
--template buf.router.go.gen.yaml + rm -rf router/gen && buf generate --path proto/wg/cosmo/node --path proto/wg/cosmo/common --path proto/wg/cosmo/graphqlmetrics --path proto/wg/cosmo/code_mode/yoko/v1 --include-imports --template buf.router.go.gen.yaml rm -rf graphqlmetrics/gen && buf generate --path proto/wg/cosmo/graphqlmetrics --path proto/wg/cosmo/common --template buf.graphqlmetrics.go.gen.yaml rm -rf connect-go/wg && buf generate --path proto/wg/cosmo/platform --path proto/wg/cosmo/notifications --path proto/wg/cosmo/common --path proto/wg/cosmo/node --template buf.connect-go.go.gen.yaml @@ -191,14 +191,16 @@ CODE_MODE_GOCACHE ?= /tmp/cosmo-code-mode-go-build-cache .PHONY: code-mode-demo code-mode-demo-down code-mode-connect-demo code-mode-connect-demo-down -# Local Code Mode demo: small federation (employees, family, availability, -# mood) + Yoko mock + Cosmo Router with Code Mode and named operations. -# Router GraphQL on :3002, MCP on :5027. Full instructions, prerequisites -# (codex CLI on PATH), and tear-down: demo/code-mode/README.md. +# Local Code Mode demo: federation of all non-EDFS demo subgraphs +# (employees, family, hobbies, products, test1, availability, mood, countries, +# plus the products_fg feature graph) + Cosmo Router with Code Mode and +# named operations. Router GraphQL on :3002, MCP on :5027. Yoko runs as a +# separate external service expected at http://127.0.0.1:3400 — start it +# before this target (override with YOKO_URL=...). Full instructions, +# prerequisites, and tear-down: demo/code-mode/README.md. 
code-mode-demo: mkdir -p $(CODE_MODE_GOCACHE) GOCACHE=$(CODE_MODE_GOCACHE) $(MAKE) -C router build - GOCACHE=$(CODE_MODE_GOCACHE) $(MAKE) -C demo/code-mode build-yoko GOCACHE=$(CODE_MODE_GOCACHE) $(MAKE) -C demo/code-mode build-stdio-proxy $(MAKE) -C demo/code-mode compose ./demo/code-mode/start.sh @@ -208,9 +210,11 @@ code-mode-demo-down: ./demo/code-mode/start.sh --down # Runs the code-mode router from source against the yoko Connect supergraph -# (plugins + composed config live in $(YOKO_DIR)). Uses different ports than -# code-mode-demo (router 3012, MCP 5037, yoko-mock 5038) so both can run at -# the same time. Set YOKO_DIR to your local yoko checkout, e.g. +# (plugins + composed config live in $(YOKO_DIR)). Uses different router/MCP +# ports than code-mode-demo (router 3012, MCP 5037) so both can run at the +# same time, and shares the same external yoko service expected at +# http://127.0.0.1:3400 (override with YOKO_URL=...). Set YOKO_DIR to your +# local yoko checkout, e.g. # `make code-mode-connect-demo YOKO_DIR=/path/to/yoko`. # Full instructions and prerequisites: demo/code-mode-connect/README.md. YOKO_DIR ?= @@ -219,7 +223,6 @@ code-mode-connect-demo: @if [ -z "$(YOKO_DIR)" ]; then echo "YOKO_DIR is required (path to your yoko checkout). See demo/code-mode-connect/README.md" >&2; exit 1; fi mkdir -p $(CODE_MODE_GOCACHE) GOCACHE=$(CODE_MODE_GOCACHE) $(MAKE) -C router build - GOCACHE=$(CODE_MODE_GOCACHE) $(MAKE) -C demo/code-mode build-yoko YOKO_DIR=$(YOKO_DIR) ./demo/code-mode-connect/start.sh # Tear down anything left behind by code-mode-connect-demo. diff --git a/buf.lock b/buf.lock new file mode 100644 index 0000000000..709ae02396 --- /dev/null +++ b/buf.lock @@ -0,0 +1,6 @@ +# Generated by buf. DO NOT EDIT. 
+version: v2 +deps: + - name: buf.build/bufbuild/protovalidate + commit: 50325440f8f24053b047484a6bf60b76 + digest: b5:74cb6f5c0853c3c10aafc701614194bbd63326bdb8ef4068214454b8894b03ba4113e04b3a33a8321cdf05336e37db4dc14a5e2495db8462566914f36086ba31 diff --git a/buf.yaml b/buf.yaml index ee5c8b279b..ecbcf646a5 100644 --- a/buf.yaml +++ b/buf.yaml @@ -1,5 +1,7 @@ version: v2 modules: - path: proto +deps: + - buf.build/bufbuild/protovalidate lint: disallow_comment_ignores: true diff --git a/demo/code-mode-connect/README.md b/demo/code-mode-connect/README.md index cdc494e5ec..4176cb0ad4 100644 --- a/demo/code-mode-connect/README.md +++ b/demo/code-mode-connect/README.md @@ -3,7 +3,7 @@ This demo runs the Code Mode router against an external `yoko` Connect supergraph instead of the local employees federation used by `make code-mode-demo`. It is useful when you want to exercise Code Mode against a richer set of plugins (Pylon, Linear, PostHog, Circleback, Slack, Notion) served by the `yoko` project. -It is designed to coexist with `make code-mode-demo`: it uses different ports (router 3012, MCP 5037, yoko-mock 5038), so both demos can run side-by-side. +It is designed to coexist with `make code-mode-demo`: it uses different router/MCP ports (router 3012, MCP 5037), and both demos share the same external Yoko service at `http://127.0.0.1:3400` (override with `YOKO_URL=...`). ## Prerequisites @@ -12,7 +12,8 @@ It is designed to coexist with `make code-mode-demo`: it uses different ports (r - `config.json` — the composed router config for the yoko supergraph. - `plugins/` — the plugin binaries the router will load. - Go (toolchain matching the repo `go.mod`). -- The `codex` CLI on `PATH`, authenticated. The Yoko mock shells out to `codex` for query generation. +- A running Yoko service reachable at `http://127.0.0.1:3400` (override with `YOKO_URL=...`). + The router calls Yoko for query generation; without it, `code_mode_search_tools` cannot generate operations. 
 ## Run @@ -28,16 +29,15 @@ The target fails fast with a clear error if it is missing or if the directory do What the target does: 1. Builds `router/router`. -2. Builds `demo/code-mode/yoko-mock/yoko-mock`. -3. Starts `yoko-mock` on `localhost:5038`. -4. Starts the router with `YOKO_DIR` as its working directory and `demo/code-mode-connect/router-config.yaml` as its config. +2. Checks the external Yoko service is reachable at `$YOKO_URL` (default `http://127.0.0.1:3400`) — any HTTP response counts; no specific `/health` path is probed. +3. Starts the router with `YOKO_DIR` as its working directory and `demo/code-mode-connect/router-config.yaml` as its config. The router resolves `config.json` and `plugins/` relative to that CWD, which is why `YOKO_DIR` must be a real composed yoko checkout. Expected ports: - Router GraphQL: `http://localhost:3012/graphql` - Code Mode MCP: `http://127.0.0.1:5037/mcp` -- Yoko mock: `http://localhost:5038` +- Yoko (external): `http://127.0.0.1:3400` ## Tearing down diff --git a/demo/code-mode-connect/router-config.yaml b/demo/code-mode-connect/router-config.yaml index b2e102fe00..a4f2b8b8b6 100644 --- a/demo/code-mode-connect/router-config.yaml +++ b/demo/code-mode-connect/router-config.yaml @@ -1,8 +1,8 @@ version: "1" -# Different ports than demo/code-mode/router-config.yaml so both demos can run -# side-by-side. See demo/code-mode-connect/start.sh for the matching yoko-mock -# port. +# Different router/MCP ports than demo/code-mode/router-config.yaml so both +# demos can run side-by-side. They share the same external yoko service +# (http://127.0.0.1:3400) — start.sh no longer launches a local yoko-mock.
listen_addr: "localhost:3012" graphql_path: "/graphql" playground_enabled: false @@ -75,7 +75,7 @@ mcp: timeout: 180s query_generation: enabled: true - endpoint: http://localhost:5038 + endpoint: http://127.0.0.1:3400 timeout: 180s execute_timeout: 180s named_ops: diff --git a/demo/code-mode-connect/start.sh b/demo/code-mode-connect/start.sh index 7379fa85d3..3a96f1a310 100755 --- a/demo/code-mode-connect/start.sh +++ b/demo/code-mode-connect/start.sh @@ -15,7 +15,10 @@ YOKO_DIR="${YOKO_DIR:?YOKO_DIR is required (path to your yoko checkout)}" ROUTER_BIN="$ROOT_DIR/router/router" ROUTER_CONFIG="$CONNECT_DIR/router-config.yaml" -YOKO_BIN="$DEMO_DIR/code-mode/yoko-mock/yoko-mock" + +# Yoko is a separate service expected at http://127.0.0.1:3400. start.sh no +# longer launches a local mock — bring up your real yoko service before running. +YOKO_URL="${YOKO_URL:-http://127.0.0.1:3400}" append_pid() { local name="$1" @@ -57,6 +60,12 @@ cleanup() { exit "$status" } +on_signal() { + trap - EXIT INT TERM + kill_pid_file + exit 0 +} + wait_url() { local name="$1" local url="$2" @@ -112,12 +121,6 @@ if [ ! -x "$ROUTER_BIN" ]; then exit 1 fi -if [ ! -x "$YOKO_BIN" ]; then - echo "Yoko mock binary not found or not executable: $YOKO_BIN" >&2 - echo "Run: make -C demo/code-mode build-yoko" >&2 - exit 1 -fi - if [ ! -f "$YOKO_DIR/config.json" ]; then echo "Composed yoko supergraph not found: $YOKO_DIR/config.json" >&2 echo "Run: cd $YOKO_DIR && make compose" >&2 @@ -127,19 +130,29 @@ fi mkdir -p "$LOG_DIR" mkdir -p "$GOCACHE_DIR" rm -f "$PID_FILE" -trap cleanup EXIT INT TERM - -# yoko-mock listens on a different port than the regular code-mode-demo so the -# two demos can coexist (5028 vs 5038). -start_background_root yoko "$YOKO_BIN" -listen-addr localhost:5038 - -wait_url yoko http://localhost:5038/health +trap cleanup EXIT +trap on_signal INT TERM + +# Verify the external yoko service is reachable. 
We don't probe a specific +# path because the real service doesn't necessarily expose /health — just +# confirm the TCP/HTTP socket accepts a connection. Any HTTP response (200, +# 404, 405 …) means the server is up; only a connection failure aborts. +# Override with YOKO_URL when yoko runs at a different address. +if ! curl -sS -o /dev/null --max-time 3 "$YOKO_URL" >/dev/null 2>&1; then + echo "Yoko service is not reachable at $YOKO_URL" >&2 + echo "Start your yoko service (or set YOKO_URL=...) before running this demo." >&2 + exit 1 +fi +echo "yoko is ready at $YOKO_URL" echo "Starting router in foreground (CWD=$YOKO_DIR)" +echo "Router output is being teed to $LOG_DIR/router.log" +# Tee stdout+stderr so the user still sees live output AND we keep a persistent +# log for post-mortem debugging when the router exits unexpectedly. ( cd "$YOKO_DIR" exec "$ROUTER_BIN" -config "$ROUTER_CONFIG" -) & +) 2>&1 | tee "$LOG_DIR/router.log" & router_pid="$!" append_pid router "$router_pid" diff --git a/demo/code-mode/README.md b/demo/code-mode/README.md index dee17d14a2..9710ef72b1 100644 --- a/demo/code-mode/README.md +++ b/demo/code-mode/README.md @@ -1,14 +1,15 @@ # Code Mode Demo -This demo starts a small local federation (`employees`, `family`, `availability`, and `mood`), the Code Mode Yoko mock, and a local Cosmo Router with Code Mode and named operations enabled. +This demo starts a local federation of all non-EDFS demo subgraphs (`employees`, `family`, `hobbies`, `products`, `test1`, `availability`, `mood`, `countries`, plus the `products_fg` feature graph under feature flag `myff`) and a local Cosmo Router with Code Mode and named operations enabled. The router talks to an external Yoko service for query generation — start that separately before running the demo. + +The set mirrors `demo/graph-no-edg.yaml`. The `employeeupdated` subgraph is intentionally excluded because it relies on EDFS (NATS) streams. ## Prerequisites - Go (toolchain matching the repo `go.mod`). 
- Node + `pnpm` (used by `wgc` to compose `demo/code-mode/graph.yaml`). -- The `codex` CLI on `PATH`, authenticated. - The Yoko mock shells out to `codex` for query generation; - without it, `code_mode_search_tools` cannot generate operations. +- A running Yoko service reachable at `http://127.0.0.1:3400` (override with `YOKO_URL=...`). + The router calls Yoko for `code_mode_search_tools`; without it, query generation will fail. ## Quick start @@ -18,23 +19,28 @@ Run it from the repository root: make code-mode-demo ``` -The root target builds `router/router`, builds `demo/code-mode/yoko-mock/yoko-mock`, builds `demo/code-mode/mcp-stdio-proxy/mcp-stdio-proxy` (used by stdio-only MCP clients like Claude Desktop), composes `demo/code-mode/graph.yaml` into `demo/code-mode/config.json`, then starts the demo processes. -The router stays in the foreground. +The root target builds `router/router`, builds `demo/code-mode/mcp-stdio-proxy/mcp-stdio-proxy` (used by stdio-only MCP clients like Claude Desktop), composes `demo/code-mode/graph.yaml` into `demo/code-mode/config.json`, then starts the demo processes. +The router stays in the foreground. `start.sh` health-checks the external Yoko service before the router starts. Expected ports: - Router GraphQL: `http://localhost:3002/graphql` - Code Mode MCP: `http://localhost:5027/mcp` -- Yoko mock: `http://localhost:5028` +- Yoko (external): `http://127.0.0.1:3400` - Employees subgraph: `http://localhost:4001/graphql` - Family subgraph: `http://localhost:4002/graphql` +- Hobbies subgraph: `http://localhost:4003/graphql` +- Products subgraph: `http://localhost:4004/graphql` +- Test1 subgraph: `http://localhost:4006/graphql` - Availability subgraph: `http://localhost:4007/graphql` - Mood subgraph: `http://localhost:4008/graphql` +- Countries subgraph: `http://localhost:4009/graphql` +- Products_fg feature graph: `http://localhost:4010/graphql` ## Tearing down To stop the demo, press Ctrl-C in the foreground terminal. 
-If anything is left behind (background subgraphs, yoko-mock), run: +If anything is left behind (background subgraphs), run: ```sh make code-mode-demo-down @@ -53,7 +59,7 @@ curl -sS http://localhost:3002/graphql \ ## Other notes -The subset runner is `demo/code-mode/run_subgraphs_subset.sh`. It starts only `employees`, `family`, `availability`, and `mood` via `npx concurrently` for a fast demo. `availability` and `mood` are included because the `employees` schema has federation references to fields owned by those subgraphs. The full demo `demo/run_subgraphs.sh` starts all subgraphs and is intentionally not used here. +The subset runner is `demo/code-mode/run_subgraphs_subset.sh`. It starts every non-EDFS subgraph used by this demo (`employees`, `family`, `hobbies`, `products`, `test1`, `availability`, `mood`, `countries`, `products_fg`) via `npx concurrently`. The full demo `demo/run_subgraphs.sh` additionally starts the EDFS-dependent `employeeupdated` subgraph and is intentionally not used here. Client configuration for Code Mode MCP clients (Claude Code, Claude Desktop, Codex CLI) lives under `demo/code-mode/mcp-configs/` — see the README there. 
diff --git a/demo/code-mode/graph.yaml b/demo/code-mode/graph.yaml index c67180f295..2cd374a1f5 100644 --- a/demo/code-mode/graph.yaml +++ b/demo/code-mode/graph.yaml @@ -1,4 +1,12 @@ version: 1 +feature_flags: + - name: myff + feature_graphs: + - name: products_fg + subgraph_name: products + routing_url: http://localhost:4010/graphql + schema: + file: schemas/products_fg.graphqls subgraphs: - name: employees routing_url: http://localhost:4001/graphql @@ -8,6 +16,18 @@ subgraphs: routing_url: http://localhost:4002/graphql schema: file: schemas/family.graphqls + - name: hobbies + routing_url: http://localhost:4003/graphql + schema: + file: schemas/hobbies.graphqls + - name: products + routing_url: http://localhost:4004/graphql + schema: + file: schemas/products.graphqls + - name: test1 + routing_url: http://localhost:4006/graphql + schema: + file: schemas/test1.graphqls - name: availability routing_url: http://localhost:4007/graphql schema: @@ -16,3 +36,7 @@ subgraphs: routing_url: http://localhost:4008/graphql schema: file: schemas/mood.graphqls + - name: countries + routing_url: http://localhost:4009/graphql + schema: + file: schemas/countries.graphqls diff --git a/demo/code-mode/mcp-stdio-proxy/go.mod b/demo/code-mode/mcp-stdio-proxy/go.mod index 4720b36f3c..67e4d67f22 100644 --- a/demo/code-mode/mcp-stdio-proxy/go.mod +++ b/demo/code-mode/mcp-stdio-proxy/go.mod @@ -1,6 +1,6 @@ module github.com/wundergraph/cosmo/demo/code-mode/mcp-stdio-proxy -go 1.25 +go 1.25.0 require ( github.com/modelcontextprotocol/go-sdk v1.4.1 @@ -8,13 +8,16 @@ require ( ) require ( - github.com/davecgh/go-spew v1.1.1 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/google/jsonschema-go v0.4.2 // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/kr/pretty v0.3.1 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + github.com/rogpeppe/go-internal v1.13.1 // indirect 
github.com/segmentio/asm v1.1.3 // indirect github.com/segmentio/encoding v0.5.4 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect golang.org/x/oauth2 v0.34.0 // indirect golang.org/x/sys v0.40.0 // indirect + gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/demo/code-mode/mcp-stdio-proxy/go.sum b/demo/code-mode/mcp-stdio-proxy/go.sum index e469bb22cf..f2ce6f5233 100644 --- a/demo/code-mode/mcp-stdio-proxy/go.sum +++ b/demo/code-mode/mcp-stdio-proxy/go.sum @@ -1,15 +1,17 @@ -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/jsonschema-go v0.4.2 h1:tmrUohrwoLZZS/P3x7ex0WAVknEkBZM46iALbcqoRA8= github.com/google/jsonschema-go v0.4.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/modelcontextprotocol/go-sdk v1.4.1 h1:M4x9GyIPj+HoIlHNGpK2hq5o3BFhC+78PkEaldQRphc= github.com/modelcontextprotocol/go-sdk v1.4.1/go.mod h1:Bo/mS87hPQqHSRkMv4dQq1XCu6zv4INdXnFZabkNU6s= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 
+github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= github.com/segmentio/asm v1.1.3 h1:WM03sfUOENvvKexOLp+pCqgb/WDjsi7EK8gIsICtzhc= github.com/segmentio/asm v1.1.3/go.mod h1:Ld3L4ZXGNcSLRg4JBsZ3//1+f/TjYl0Mzen/DQy1EJg= github.com/segmentio/encoding v0.5.4 h1:OW1VRern8Nw6ITAtwSZ7Idrl3MXCFwXHPgqESYfvNt0= @@ -24,7 +26,7 @@ golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc= golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/demo/code-mode/prepare-schemas.sh b/demo/code-mode/prepare-schemas.sh index 92fd8f6dae..017539d589 100755 --- a/demo/code-mode/prepare-schemas.sh +++ b/demo/code-mode/prepare-schemas.sh @@ -21,14 +21,17 @@ strip_auth() { local out="$2" # Remove @requiresScopes(scopes: [[...], [...]]) — match the doubly-nested # bracket payload, then drop @authenticated standalone uses. The directive - # imports inside @link(import: [...]) stay (they're string literals, not - # directive applications, so they don't trigger enforcement). + # imports inside @link(import: [...]) are quoted string literals, so the + # @authenticated rule requires a non-quote, non-word predecessor (start of + # line or whitespace) to leave those imports intact. 
+ # + # POSIX-portable: avoid \b (BSD sed treats it as a literal `b`). sed -E ' s/[[:space:]]*@requiresScopes\(scopes:[[:space:]]*\[(\[[^][]*\][, ]*)+\]\)//g - s/[[:space:]]*@authenticated\b//g + s/(^|[[:space:]])@authenticated([^a-zA-Z0-9_]|$)/\1\2/g ' "$in" > "$out" } -for sg in employees family availability mood; do +for sg in employees family availability mood hobbies products test1 countries products_fg; do strip_auth "$SRC_DIR/$sg/subgraph/schema.graphqls" "$OUT_DIR/$sg.graphqls" done diff --git a/demo/code-mode/router-config.yaml b/demo/code-mode/router-config.yaml index 01fac390a2..3c676db755 100644 --- a/demo/code-mode/router-config.yaml +++ b/demo/code-mode/router-config.yaml @@ -48,7 +48,7 @@ mcp: timeout: 180s query_generation: enabled: true - endpoint: http://localhost:5028 + endpoint: http://127.0.0.1:3400 timeout: 180s execute_timeout: 180s named_ops: diff --git a/demo/code-mode/run_subgraphs_subset.sh b/demo/code-mode/run_subgraphs_subset.sh index 23e2c20ec3..d393e551b9 100755 --- a/demo/code-mode/run_subgraphs_subset.sh +++ b/demo/code-mode/run_subgraphs_subset.sh @@ -9,5 +9,10 @@ mkdir -p "$GOCACHE" npx concurrently --kill-others \ "GOCACHE=$GOCACHE PORT=4001 go run ./cmd/employees" \ "GOCACHE=$GOCACHE PORT=4002 go run ./cmd/family" \ + "GOCACHE=$GOCACHE PORT=4003 go run ./cmd/hobbies" \ + "GOCACHE=$GOCACHE PORT=4004 go run ./cmd/products" \ + "GOCACHE=$GOCACHE PORT=4006 go run ./cmd/test1" \ "GOCACHE=$GOCACHE PORT=4007 go run ./cmd/availability" \ - "GOCACHE=$GOCACHE PORT=4008 go run ./cmd/mood" + "GOCACHE=$GOCACHE PORT=4008 go run ./cmd/mood" \ + "GOCACHE=$GOCACHE PORT=4009 go run ./cmd/countries" \ + "GOCACHE=$GOCACHE PORT=4010 go run ./cmd/products_fg" diff --git a/demo/code-mode/start.sh b/demo/code-mode/start.sh index c079e1d1db..ab3e353602 100755 --- a/demo/code-mode/start.sh +++ b/demo/code-mode/start.sh @@ -11,7 +11,10 @@ GOCACHE_DIR="${GOCACHE:-/tmp/cosmo-code-mode-go-build-cache}" ROUTER_BIN="$ROOT_DIR/router/router" 
ROUTER_CONFIG="$CODE_MODE_DIR/router-config.yaml" -YOKO_BIN="$CODE_MODE_DIR/yoko-mock/yoko-mock" + +# Yoko is a separate service expected at http://127.0.0.1:3400. start.sh no +# longer launches a local mock — bring up your real yoko service before running. +YOKO_URL="${YOKO_URL:-http://127.0.0.1:3400}" append_pid() { local name="$1" @@ -112,12 +115,6 @@ if [ ! -x "$ROUTER_BIN" ]; then exit 1 fi -if [ ! -x "$YOKO_BIN" ]; then - echo "Yoko mock binary not found or not executable: $YOKO_BIN" >&2 - echo "Run: cd demo/code-mode/yoko-mock && go build -o yoko-mock ." >&2 - exit 1 -fi - if [ ! -f "$CODE_MODE_DIR/config.json" ]; then echo "Composed router config not found: $CODE_MODE_DIR/config.json" >&2 echo "Run: make -C demo/code-mode compose" >&2 @@ -131,18 +128,41 @@ trap cleanup EXIT INT TERM start_background employees "$DEMO_DIR" env GOCACHE="$GOCACHE_DIR" PORT=4001 go run ./cmd/employees start_background family "$DEMO_DIR" env GOCACHE="$GOCACHE_DIR" PORT=4002 go run ./cmd/family +start_background hobbies "$DEMO_DIR" env GOCACHE="$GOCACHE_DIR" PORT=4003 go run ./cmd/hobbies +start_background products "$DEMO_DIR" env GOCACHE="$GOCACHE_DIR" PORT=4004 go run ./cmd/products +start_background test1 "$DEMO_DIR" env GOCACHE="$GOCACHE_DIR" PORT=4006 go run ./cmd/test1 start_background availability "$DEMO_DIR" env GOCACHE="$GOCACHE_DIR" PORT=4007 go run ./cmd/availability start_background mood "$DEMO_DIR" env GOCACHE="$GOCACHE_DIR" PORT=4008 go run ./cmd/mood -start_background_root yoko "$YOKO_BIN" -listen-addr localhost:5028 +start_background countries "$DEMO_DIR" env GOCACHE="$GOCACHE_DIR" PORT=4009 go run ./cmd/countries +start_background products_fg "$DEMO_DIR" env GOCACHE="$GOCACHE_DIR" PORT=4010 go run ./cmd/products_fg wait_url employees http://localhost:4001/ wait_url family http://localhost:4002/ +wait_url hobbies http://localhost:4003/ +wait_url products http://localhost:4004/ +wait_url test1 http://localhost:4006/ wait_url availability http://localhost:4007/ 
wait_url mood http://localhost:4008/
-wait_url yoko http://localhost:5028/health
+wait_url countries http://localhost:4009/
+wait_url products_fg http://localhost:4010/
+
+# Verify the external yoko service is reachable. We don't probe a specific
+# path because the real service doesn't necessarily expose /health — just
+# confirm the TCP/HTTP socket accepts a connection. Any HTTP response (200,
+# 404, 405 …) means the server is up; only a connection failure aborts.
+# Override with YOKO_URL when yoko runs at a different address.
+if ! curl -sS -o /dev/null --max-time 3 "$YOKO_URL" >/dev/null 2>&1; then
+  echo "Yoko service is not reachable at $YOKO_URL" >&2
+  echo "Start your yoko service (or set YOKO_URL=...) before running this demo." >&2
+  exit 1
+fi
+echo "yoko is ready at $YOKO_URL"
 
 echo "Starting router in foreground"
-"$ROUTER_BIN" -config "$ROUTER_CONFIG" &
+echo "Router output is being teed to $LOG_DIR/router.log"
+# Tee stdout+stderr via process substitution (not a pipeline) so the user
+# still sees live output, we keep a persistent log, AND $! below is the
+# router's own PID — with `… | tee … &`, $! would be tee's PID, so cleanup
+# would kill tee and orphan the router.
+"$ROUTER_BIN" -config "$ROUTER_CONFIG" > >(tee "$LOG_DIR/router.log") 2>&1 &
 router_pid="$!"
append_pid router "$router_pid" diff --git a/demo/code-mode/yoko-mock/go.mod b/demo/code-mode/yoko-mock/go.mod index 807baae723..ea0ceff1d5 100644 --- a/demo/code-mode/yoko-mock/go.mod +++ b/demo/code-mode/yoko-mock/go.mod @@ -7,14 +7,18 @@ require ( github.com/dgraph-io/ristretto/v2 v2.4.0 github.com/stretchr/testify v1.11.1 github.com/wundergraph/cosmo/router v0.0.0 + github.com/wundergraph/graphql-go-tools/v2 v2.1.0 google.golang.org/protobuf v1.36.10 ) require ( + github.com/buger/jsonparser v1.1.2 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect - github.com/davecgh/go-spew v1.1.1 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/dustin/go-humanize v1.0.1 // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + github.com/wundergraph/go-arena v1.1.0 // indirect golang.org/x/sys v0.40.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/demo/code-mode/yoko-mock/go.sum b/demo/code-mode/yoko-mock/go.sum index e60cb737e8..473aa9fbc3 100644 --- a/demo/code-mode/yoko-mock/go.sum +++ b/demo/code-mode/yoko-mock/go.sum @@ -1,9 +1,12 @@ connectrpc.com/connect v1.19.1 h1:R5M57z05+90EfEvCY1b7hBxDVOUl45PrtXtAV2fOC14= connectrpc.com/connect v1.19.1/go.mod h1:tN20fjdGlewnSFeZxLKb0xwIZ6ozc3OQs2hTXy4du9w= +github.com/buger/jsonparser v1.1.2 h1:frqHqw7otoVbk5M8LlE/L7HTnIq2v9RX6EJ48i9AxJk= +github.com/buger/jsonparser v1.1.2/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/creack/pty v1.1.9/go.mod 
h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgraph-io/ristretto/v2 v2.4.0 h1:I/w09yLjhdcVD2QV192UJcq8dPBaAJb9pOuMyNy0XlU= github.com/dgraph-io/ristretto/v2 v2.4.0/go.mod h1:0KsrXtXvnv0EqnzyowllbVJB8yBonswa2lTCK2gGo9E= github.com/dgryski/go-farm v0.0.0-20240924180020-3414d57e47da h1:aIftn67I1fkbMa512G+w+Pxci9hJPB8oMnkcP3iZF38= @@ -12,15 +15,32 @@ github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkp github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/jensneuse/diffview v1.0.0 h1:4b6FQJ7y3295JUHU3tRko6euyEboL825ZsXeZZM47Z4= +github.com/jensneuse/diffview v1.0.0/go.mod h1:i6IacuD8LnEaPuiyzMHA+Wfz5mAuycMOf3R/orUY9y4= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= +github.com/rogpeppe/go-internal v1.13.1/go.mod 
h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= +github.com/sebdah/goldie/v2 v2.7.1 h1:PkBHymaYdtvEkZV7TmyqKxdmn5/Vcj+8TpATWZjnG5E= +github.com/sebdah/goldie/v2 v2.7.1/go.mod h1:oZ9fp0+se1eapSRjfYbsV/0Hqhbuu3bJVvKI/NNtssI= +github.com/sergi/go-diff v1.3.1 h1:xkr+Oxo4BOQKmkn/B9eMK0g5Kg/983T9DqqPHwYqD+8= +github.com/sergi/go-diff v1.3.1/go.mod h1:aMJSSKb2lpPvRNec0+w3fl7LP9IOFzdc9Pa4NFbPK1I= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/wundergraph/go-arena v1.1.0 h1:9+wSRkJAkA2vbYHp6s8tEGhPViRGQNGXqPHT0QzhdIc= +github.com/wundergraph/go-arena v1.1.0/go.mod h1:ROOysEHWJjLQ8FSfNxZCziagb7Qw2nXY3/vgKRh7eWw= +github.com/wundergraph/graphql-go-tools/v2 v2.1.0 h1:V1MU/uo+oc5b+aIh3SpCr0rJgLHuhonWg2fhN1sfMdY= +github.com/wundergraph/graphql-go-tools/v2 v2.1.0/go.mod h1:UG/grnPEHumtD82H8FC+3dokiCGK8GF0b5IJc00lSbM= golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE= google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/demo/code-mode/yoko-mock/main.go b/demo/code-mode/yoko-mock/main.go index 3a412fe48f..58e86ca16d 100644 --- 
a/demo/code-mode/yoko-mock/main.go +++ b/demo/code-mode/yoko-mock/main.go @@ -24,6 +24,10 @@ import ( "github.com/dgraph-io/ristretto/v2" yokov1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/code_mode/yoko/v1" "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/code_mode/yoko/v1/yokov1connect" + "github.com/wundergraph/cosmo/router/pkg/codemode/varschema" + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" + "github.com/wundergraph/graphql-go-tools/v2/pkg/astparser" + "github.com/wundergraph/graphql-go-tools/v2/pkg/asttransform" ) const badOutputPath = "/tmp/yoko-mock-last-bad-output.log" @@ -32,45 +36,47 @@ type yokoService struct { codexBin string codexTimeout time.Duration codexReasoningEffort string - rotateAfter int // re-warm the codex session after this many Search calls; 0 disables + rotateAfter int // re-warm the codex session after this many GenerateQuery calls; 0 disables - // promptCache memoizes (schemaID, prompt) -> GeneratedOperation. A cache - // hit lets us skip codex entirely for that prompt. nil if the cache is + // promptCache memoizes (schemaID, prompt) -> ResolvedQuery. A cache hit + // lets us skip codex entirely for that prompt. nil if the cache is // disabled (size <= 0). - promptCache *ristretto.Cache[string, *yokov1.GeneratedOperation] + promptCache *ristretto.Cache[string, *yokov1.ResolvedQuery] mu sync.RWMutex schemas map[string]*schemaEntry } // schemaEntry records the on-disk schema dir (so codex can read schema.graphql -// once at Index time) plus the codex session id created during that pre-warm. -// Search uses `codex exec resume ` to reuse the already-loaded -// schema context instead of re-reading it on every call. -// -// To bound session-file growth, every yokoService.rotateAfter Search calls a -// background goroutine pre-warms a fresh session and atomically swaps the -// sessionID. searchCount tracks calls; rotationActive ensures only one -// rotation runs at a time. 
+// once at IndexSchema time) plus the codex session id created during that +// pre-warm and the parsed schema document used to derive variables_schema for +// each generated operation. type schemaEntry struct { - dir string + dir string + schema *ast.Document mu sync.RWMutex sessionID string - searchCount atomic.Int64 + generateCount atomic.Int64 rotationActive atomic.Bool } -type codexOperation struct { - Name string `json:"name"` - Body string `json:"body"` - Kind string `json:"kind"` - Description string `json:"description"` +type codexResolvedQuery struct { + Description string `json:"description"` + Document string `json:"document"` + OperationName string `json:"operation_name"` + OperationType string `json:"operation_type"` } -type codexOutput struct { - Operations []codexOperation `json:"operations"` +type codexUnsatisfied struct { + Reason string `json:"reason"` +} + +type codexResolution struct { + Queries []codexResolvedQuery `json:"queries"` + Unsatisfied []codexUnsatisfied `json:"unsatisfied"` + Truncated bool `json:"truncated"` } func main() { @@ -78,8 +84,8 @@ func main() { codexBin := flag.String("codex-bin", "codex", "codex CLI binary path or name") codexTimeout := flag.Duration("codex-timeout", 60*time.Second, "codex CLI timeout") codexReasoningEffort := flag.String("codex-reasoning-effort", "low", "codex reasoning effort: minimal | low | medium | high") - codexRotateAfter := flag.Int("codex-rotate-after", 20, "re-warm the codex session after N Search calls (0 = disable rotation)") - promptCacheSize := flag.Int("prompt-cache-size", 1000, "max items in the (schema_id, prompt) -> operation cache (0 = disable)") + codexRotateAfter := flag.Int("codex-rotate-after", 20, "re-warm the codex session after N GenerateQuery calls (0 = disable rotation)") + promptCacheSize := flag.Int("prompt-cache-size", 1000, "max items in the (schema_id, prompt) -> resolved_query cache (0 = disable)") flag.Parse() svc, err := newYokoService(*codexBin, *codexTimeout, 
*codexReasoningEffort, *codexRotateAfter, *promptCacheSize) @@ -126,7 +132,7 @@ func newYokoService(codexBin string, codexTimeout time.Duration, reasoningEffort if promptCacheSize > 0 { // Each cache entry has cost 1, so MaxCost is the item ceiling. // NumCounters is conventionally 10× expected items. - cache, err := ristretto.NewCache(&ristretto.Config[string, *yokov1.GeneratedOperation]{ + cache, err := ristretto.NewCache(&ristretto.Config[string, *yokov1.ResolvedQuery]{ NumCounters: int64(promptCacheSize) * 10, MaxCost: int64(promptCacheSize), BufferItems: 64, @@ -150,9 +156,9 @@ func newHTTPMux(svc *yokoService) *http.ServeMux { return mux } -func (s *yokoService) Index(ctx context.Context, req *connect.Request[yokov1.IndexRequest]) (*connect.Response[yokov1.IndexResponse], error) { - schemaSDL := req.Msg.GetSchemaSdl() - id := schemaID(schemaSDL) +func (s *yokoService) IndexSchema(ctx context.Context, req *connect.Request[yokov1.IndexSchemaRequest]) (*connect.Response[yokov1.IndexSchemaResponse], error) { + sdl := req.Msg.GetSdl() + id := schemaID(sdl) s.mu.Lock() if existing, ok := s.schemas[id]; ok { @@ -160,16 +166,21 @@ func (s *yokoService) Index(ctx context.Context, req *connect.Request[yokov1.Ind existing.mu.RLock() existingSession := existing.sessionID existing.mu.RUnlock() - log.Printf("Index schema_id=%s reused dir=%s session_id=%s", id, existing.dir, existingSession) - return connect.NewResponse(&yokov1.IndexResponse{SchemaId: id}), nil + log.Printf("IndexSchema schema_id=%s reused dir=%s session_id=%s", id, existing.dir, existingSession) + return connect.NewResponse(&yokov1.IndexSchemaResponse{SchemaId: id}), nil } s.mu.Unlock() + schemaDoc, err := parseSchemaSDL(sdl) + if err != nil { + return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("parse schema SDL: %w", err)) + } + dir, err := os.MkdirTemp("", "yoko-schema-"+id+"-") if err != nil { return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("create schema temp dir: %w", 
err)) } - if err := os.WriteFile(filepath.Join(dir, "schema.graphql"), []byte(schemaSDL), 0o600); err != nil { + if err := os.WriteFile(filepath.Join(dir, "schema.graphql"), []byte(sdl), 0o600); err != nil { _ = os.RemoveAll(dir) return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("write schema.graphql: %w", err)) } @@ -180,16 +191,16 @@ func (s *yokoService) Index(ctx context.Context, req *connect.Request[yokov1.Ind return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("codex pre-warm: %w", err)) } - entry := &schemaEntry{dir: dir, sessionID: sessionID} + entry := &schemaEntry{dir: dir, schema: schemaDoc, sessionID: sessionID} s.mu.Lock() s.schemas[id] = entry s.mu.Unlock() - log.Printf("Index schema_id=%s schema_sdl_size=%d schema_dir=%s session_id=%s rotate_after=%d", id, len(schemaSDL), dir, sessionID, s.rotateAfter) - return connect.NewResponse(&yokov1.IndexResponse{SchemaId: id}), nil + log.Printf("IndexSchema schema_id=%s sdl_size=%d schema_dir=%s session_id=%s rotate_after=%d", id, len(sdl), dir, sessionID, s.rotateAfter) + return connect.NewResponse(&yokov1.IndexSchemaResponse{SchemaId: id}), nil } -// Close removes every per-schema temp dir created by Index. Safe to call +// Close removes every per-schema temp dir created by IndexSchema. Safe to call // multiple times; subsequent calls are no-ops. Codex session rollout files // live under ~/.codex/sessions/ and are intentionally left in place — they // belong to the user's codex install. 
@@ -210,124 +221,101 @@ func (s *yokoService) Close() { } } -func (s *yokoService) Search(ctx context.Context, req *connect.Request[yokov1.SearchRequest]) (*connect.Response[yokov1.SearchResponse], error) { +func (s *yokoService) GenerateQuery(ctx context.Context, req *connect.Request[yokov1.GenerateQueryRequest]) (*connect.Response[yokov1.GenerateQueryResponse], error) { schemaID := req.Msg.GetSchemaId() - prompts := req.Msg.GetPrompts() + prompt := req.Msg.GetPrompt() s.mu.RLock() entry, ok := s.schemas[schemaID] s.mu.RUnlock() if !ok { - log.Printf("Search schema_id=%s prompt_count=%d not_found=true", schemaID, len(prompts)) - return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("schema_id %q not found; call Index before Search", schemaID)) + log.Printf("GenerateQuery schema_id=%s not_found=true", schemaID) + return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("schema_id %q not found; call IndexSchema before GenerateQuery", schemaID)) } // Bump per-session call counter; if we crossed the threshold and no // rotation is in flight, kick one off in the background. The CAS makes // the trigger one-shot until rotation completes and clears the flag. - count := entry.searchCount.Add(1) + count := entry.generateCount.Add(1) if s.rotateAfter > 0 && count >= int64(s.rotateAfter) && entry.rotationActive.CompareAndSwap(false, true) { go s.rotateSession(schemaID, entry, count) } - // Cache lookup: collect cached ops in their original positions, batch - // only the misses to codex. 
- results := make([]*yokov1.GeneratedOperation, len(prompts)) - missing := make([]string, 0, len(prompts)) - missingIdx := make([]int, 0, len(prompts)) - hits := 0 - for i, p := range prompts { - if op, ok := s.cacheGet(schemaID, p); ok { - results[i] = op - hits++ - } else { - missing = append(missing, p) - missingIdx = append(missingIdx, i) - } - } - - if len(missing) == 0 { - log.Printf("Search schema_id=%s prompt_count=%d cache_hits=%d cache_misses=0 codex_skipped=true", schemaID, len(prompts), hits) - return connect.NewResponse(&yokov1.SearchResponse{Operations: filterNonNil(results)}), nil + if cached, ok := s.cacheGet(schemaID, prompt); ok { + log.Printf("GenerateQuery schema_id=%s cache_hit=true codex_skipped=true", schemaID) + return connect.NewResponse(&yokov1.GenerateQueryResponse{ + Resolution: &yokov1.Resolution{Queries: []*yokov1.ResolvedQuery{cached}}, + }), nil } entry.mu.RLock() sessionID := entry.sessionID entry.mu.RUnlock() - prompt := buildCodexPrompt(missing) - stdout, err := s.runCodexResume(ctx, sessionID, prompt) + codexPrompt := buildCodexPrompt(prompt) + stdout, err := s.runCodexResume(ctx, sessionID, codexPrompt) if err != nil { return nil, connect.NewError(connect.CodeInternal, err) } - generated, err := parseCodexOperations(stdout) + resolution, err := parseCodexResolution(stdout) if err != nil { if writeErr := os.WriteFile(badOutputPath, stdout, 0o600); writeErr != nil { log.Printf("warning: failed to write bad codex output path=%s err=%v", badOutputPath, writeErr) } - log.Printf("warning: codex output was not valid JSON schema_id=%s prompt_count=%d stdout_size=%d err=%v", schemaID, len(missing), len(stdout), err) + log.Printf("warning: codex output was not valid JSON schema_id=%s stdout_size=%d err=%v", schemaID, len(stdout), err) return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("codex output was not valid JSON; raw output saved to %s", badOutputPath)) } - // Pair generated ops back into the original prompt slots and 
cache the - // successful ones. We trust order: codex was instructed to return one - // operation per missing prompt in the same order. If codex returned - // fewer ops than asked, the trailing prompts have no slot filled (and - // don't get cached). - for k, idx := range missingIdx { - if k >= len(generated) { - break - } - op := generated[k] - if op == nil || op.GetBody() == "" { - // Failed prompt — don't cache, leave slot nil (filtered out below). + for _, q := range resolution.GetQueries() { + // Derive variables_schema statically from the parsed schema. If + // derivation fails we leave variables_schema empty so the client + // still gets a usable response — the agent can validate manually. + varsSchema, derr := varschema.ForOperation(q.GetDocument(), entry.schema) + if derr != nil { + log.Printf("warning: derive variables_schema schema_id=%s op=%q err=%v", schemaID, q.GetOperationName(), derr) continue } - results[idx] = op - s.cachePut(schemaID, missing[k], op) + q.VariablesSchema = varsSchema } - log.Printf("Search schema_id=%s prompt_count=%d cache_hits=%d cache_misses=%d codex_stdout_size=%d parsed_op_count=%d", schemaID, len(prompts), hits, len(missing), len(stdout), len(generated)) - return connect.NewResponse(&yokov1.SearchResponse{Operations: filterNonNil(results)}), nil -} - -func filterNonNil(ops []*yokov1.GeneratedOperation) []*yokov1.GeneratedOperation { - out := ops[:0] - for _, op := range ops { - if op != nil { - out = append(out, op) - } + // Cache successful single-query resolutions only — caching multi-query + // or unsatisfied resolutions would hide real codex variation. 
+ if len(resolution.GetQueries()) == 1 && len(resolution.GetUnsatisfied()) == 0 && !resolution.GetTruncated() { + s.cachePut(schemaID, prompt, resolution.GetQueries()[0]) } - return out + + log.Printf("GenerateQuery schema_id=%s codex_stdout_size=%d query_count=%d unsatisfied_count=%d truncated=%v", + schemaID, len(stdout), len(resolution.GetQueries()), len(resolution.GetUnsatisfied()), resolution.GetTruncated()) + return connect.NewResponse(&yokov1.GenerateQueryResponse{Resolution: resolution}), nil } // cacheKey returns the (schema_id, prompt) lookup key. We include schema_id // so the same prompt against a different supergraph doesn't return a stale -// operation. +// query. func cacheKey(schemaID, prompt string) string { return schemaID + "\x00" + prompt } -func (s *yokoService) cacheGet(schemaID, prompt string) (*yokov1.GeneratedOperation, bool) { +func (s *yokoService) cacheGet(schemaID, prompt string) (*yokov1.ResolvedQuery, bool) { if s.promptCache == nil { return nil, false } return s.promptCache.Get(cacheKey(schemaID, prompt)) } -func (s *yokoService) cachePut(schemaID, prompt string, op *yokov1.GeneratedOperation) { +func (s *yokoService) cachePut(schemaID, prompt string, q *yokov1.ResolvedQuery) { if s.promptCache == nil { return } - s.promptCache.Set(cacheKey(schemaID, prompt), op, 1) + s.promptCache.Set(cacheKey(schemaID, prompt), q, 1) } -// rotateSession is launched in a goroutine when Search counts cross +// rotateSession is launched in a goroutine when GenerateQuery counts cross // rotateAfter. It pre-warms a fresh codex session against the same on-disk -// schema, then atomically swaps in the new sessionID and resets the search -// counter. While rotation is running, concurrent Search calls keep using the -// old sessionID — they just don't trigger a second rotation. +// schema, then atomically swaps in the new sessionID and resets the call +// counter. 
While rotation is running, concurrent calls keep using the old +// sessionID — they just don't trigger a second rotation. func (s *yokoService) rotateSession(schemaID string, entry *schemaEntry, triggerCount int64) { start := time.Now() log.Printf("rotation kickoff schema_id=%s trigger_count=%d", schemaID, triggerCount) @@ -347,19 +335,30 @@ func (s *yokoService) rotateSession(schemaID string, entry *schemaEntry, trigger entry.sessionID = newSessionID entry.mu.Unlock() - // Reset count BEFORE clearing rotationActive so a Search arriving in this + // Reset count BEFORE clearing rotationActive so a call arriving in this // gap can't trigger a second rotation on a freshly-rotated session. - entry.searchCount.Store(0) + entry.generateCount.Store(0) entry.rotationActive.Store(false) log.Printf("rotation complete schema_id=%s old_session=%s new_session=%s elapsed=%s", schemaID, oldSessionID, newSessionID, time.Since(start).Round(time.Millisecond)) } -func schemaID(schemaSDL string) string { - sum := sha256.Sum256([]byte(schemaSDL)) +func schemaID(sdl string) string { + sum := sha256.Sum256([]byte(sdl)) return fmt.Sprintf("%x", sum)[:16] } +func parseSchemaSDL(sdl string) (*ast.Document, error) { + doc, report := astparser.ParseGraphqlDocumentString(sdl) + if report.HasErrors() { + return nil, fmt.Errorf("parse SDL: %s", report.Error()) + } + if err := asttransform.MergeDefinitionWithBaseSchema(&doc); err != nil { + return nil, fmt.Errorf("merge base schema: %w", err) + } + return &doc, nil +} + const indexCodexPrompt = `Read the COMPLETE content of the file ./schema.graphql in your current working directory using your file-reading tool. Read the ENTIRE file (it is approximately 17KB and 824 lines) — do not truncate, do not skim, do not read only a portion. The file is a federated GraphQL supergraph SDL. 
Once the full schema is loaded into your context, output exactly this JSON object and nothing else: @@ -368,14 +367,14 @@ Once the full schema is loaded into your context, output exactly this JSON objec Do not include preamble, prose, markdown fences, or commentary.` -func buildCodexPrompt(prompts []string) string { +func buildCodexPrompt(prompt string) string { var b strings.Builder b.WriteString("You already loaded the federated GraphQL supergraph SDL from\n") b.WriteString("./schema.graphql earlier in this session. Use it as the source of\n") b.WriteString("truth — do not re-read the file.\n\n") - b.WriteString("For each user prompt below, generate ONE corresponding GraphQL\n") - b.WriteString("operation (query or mutation) that fulfills the prompt against\n") - b.WriteString("the schema. Return one operation per prompt, in the same order.\n\n") + b.WriteString("Generate one or more GraphQL operations (query or mutation) that\n") + b.WriteString("together fulfill the user prompt below against the schema. Each\n") + b.WriteString("operation must be self-contained and named.\n\n") b.WriteString("PARAMETERIZATION REQUIREMENT (load-bearing):\n") b.WriteString("Whenever an argument's value depends on the caller's intent (an id,\n") b.WriteString("a filter, a name, a tag, a limit, etc.), you MUST declare a GraphQL\n") @@ -387,28 +386,30 @@ func buildCodexPrompt(prompts []string) string { b.WriteString("(for example, 'list ALL employees' might pass no args at all). Variable\n") b.WriteString("types must match the schema, including non-null bangs.\n\n") b.WriteString("OUTPUT FORMAT (strict, machine-parsed):\n") - b.WriteString("- Output a single JSON object with one key: \"operations\" (array).\n") - b.WriteString("- Each operation has keys: name (camelCase), body (operation\n") - b.WriteString(" source text starting with 'query (...)' or\n") - b.WriteString(" 'mutation (...)' when variables are declared, or\n") - b.WriteString(" 'query { ... }' / 'mutation { ... 
}' when truly\n") - b.WriteString(" variable-free), kind ('query' or 'mutation'), description\n") - b.WriteString(" (one short sentence).\n") - b.WriteString("- operations.length MUST equal the number of user prompts below,\n") - b.WriteString(" in the same order.\n") + b.WriteString("- Output a single JSON object with these keys:\n") + b.WriteString(" - queries: array of objects, each with keys:\n") + b.WriteString(" description (one short sentence describing what this query does),\n") + b.WriteString(" document (operation source text starting with 'query (...)'\n") + b.WriteString(" or 'mutation (...)' when variables are declared, or\n") + b.WriteString(" 'query { ... }' / 'mutation { ... }' when\n") + b.WriteString(" variable-free),\n") + b.WriteString(" operation_name (the name parsed from the document),\n") + b.WriteString(" operation_type (\"query\" or \"mutation\").\n") + b.WriteString(" - unsatisfied: array of {\"reason\": \"...\"} for any requirement that\n") + b.WriteString(" cannot be satisfied against the schema. Empty array if everything\n") + b.WriteString(" could be satisfied.\n") + b.WriteString(" - truncated: boolean. true only if you ran out of room before\n") + b.WriteString(" committing every requirement.\n") b.WriteString("- No prose, no preamble, no markdown fences.\n\n") - b.WriteString("USER PROMPTS:\n") - for _, prompt := range prompts { - b.WriteString("- ") - b.WriteString(prompt) - b.WriteByte('\n') - } + b.WriteString("USER PROMPT:\n") + b.WriteString(prompt) + b.WriteByte('\n') return b.String() } // runCodexIndex performs the one-time pre-warm: codex reads schema.graphql in // schemaDir and a session is started. The session id (UUID) is parsed from -// codex's first JSONL event and returned so subsequent Search calls can resume +// codex's first JSONL event and returned so subsequent calls can resume // the same session. 
func (s *yokoService) runCodexIndex(ctx context.Context, schemaDir string) (string, error) { cmdCtx, cancel := context.WithTimeout(ctx, s.codexTimeout) @@ -457,13 +458,13 @@ func (s *yokoService) runCodexIndex(ctx context.Context, schemaDir string) (stri } // runCodexResume resumes the previously-warmed session and runs the user -// prompts. The agent's last message (a JSON object of operations) is captured -// via `--output-last-message` and returned for parsing. +// prompt. The agent's last message (a JSON resolution) is captured via +// `--output-last-message` and returned for parsing. func (s *yokoService) runCodexResume(ctx context.Context, sessionID, prompt string) ([]byte, error) { cmdCtx, cancel := context.WithTimeout(ctx, s.codexTimeout) defer cancel() - outFile, err := os.CreateTemp("", "yoko-search-out-*.txt") + outFile, err := os.CreateTemp("", "yoko-generate-out-*.txt") if err != nil { return nil, fmt.Errorf("create output temp file: %w", err) } @@ -532,34 +533,33 @@ func parseThreadID(stdout []byte) (string, error) { return ev.ThreadID, nil } -func parseCodexOperations(stdout []byte) ([]*yokov1.GeneratedOperation, error) { +func parseCodexResolution(stdout []byte) (*yokov1.Resolution, error) { payload := extractJSONObject(stdout) - var parsed codexOutput + var parsed codexResolution if err := json.Unmarshal(payload, &parsed); err != nil { return nil, err } - ops := make([]*yokov1.GeneratedOperation, 0, len(parsed.Operations)) - for _, op := range parsed.Operations { - ops = append(ops, &yokov1.GeneratedOperation{ - Name: op.Name, - Body: op.Body, - Kind: operationKind(op.Kind), - Description: op.Description, + queries := make([]*yokov1.ResolvedQuery, 0, len(parsed.Queries)) + for _, q := range parsed.Queries { + queries = append(queries, &yokov1.ResolvedQuery{ + Description: q.Description, + Document: q.Document, + OperationName: q.OperationName, + OperationType: strings.ToLower(strings.TrimSpace(q.OperationType)), }) } - return ops, nil -} -func 
operationKind(kind string) yokov1.OperationKind { - switch strings.ToLower(kind) { - case "query": - return yokov1.OperationKind_OPERATION_KIND_QUERY - case "mutation": - return yokov1.OperationKind_OPERATION_KIND_MUTATION - default: - return yokov1.OperationKind_OPERATION_KIND_UNSPECIFIED + unsatisfied := make([]*yokov1.Unsatisfied, 0, len(parsed.Unsatisfied)) + for _, u := range parsed.Unsatisfied { + unsatisfied = append(unsatisfied, &yokov1.Unsatisfied{Reason: u.Reason}) } + + return &yokov1.Resolution{ + Queries: queries, + Unsatisfied: unsatisfied, + Truncated: parsed.Truncated, + }, nil } // extractJSONObject returns the substring from the first '{' to the last '}' diff --git a/demo/code-mode/yoko-mock/main_test.go b/demo/code-mode/yoko-mock/main_test.go index 61b0ea3d4f..e5006f3735 100644 --- a/demo/code-mode/yoko-mock/main_test.go +++ b/demo/code-mode/yoko-mock/main_test.go @@ -18,69 +18,122 @@ import ( "google.golang.org/protobuf/proto" ) -func TestIndexThenSearchReturnsGeneratedOperations(t *testing.T) { +const testSDL = `type Query { viewer: User } type User { id: ID! 
}` + +func TestIndexSchemaThenGenerateQueryReturnsResolvedQuery(t *testing.T) { + writeFakeCodex(t, + `{"type":"thread.started","thread_id":"fake-thread"}`, + `{"queries":[{"description":"Fetches the current viewer.","document":"query getViewer { viewer { id } }","operation_name":"getViewer","operation_type":"query"}],"unsatisfied":[],"truncated":false}`, + ) + client := newTestClient(t) + + indexResp, err := client.IndexSchema(context.Background(), connect.NewRequest(&yokov1.IndexSchemaRequest{ + Sdl: testSDL, + })) + require.NoError(t, err) + + resp, err := client.GenerateQuery(context.Background(), connect.NewRequest(&yokov1.GenerateQueryRequest{ + SchemaId: indexResp.Msg.GetSchemaId(), + Prompt: "get the viewer", + })) + require.NoError(t, err) + + expected := &yokov1.GenerateQueryResponse{ + Resolution: &yokov1.Resolution{ + Queries: []*yokov1.ResolvedQuery{{ + Description: "Fetches the current viewer.", + Document: "query getViewer { viewer { id } }", + OperationName: "getViewer", + OperationType: "query", + VariablesSchema: `{"type":"object","properties":{}}`, + }}, + }, + } + assert.Equal(t, normalizeGenerateResponse(t, expected), normalizeGenerateResponse(t, resp.Msg)) +} + +func TestGenerateQueryDerivesVariablesSchemaFromOperation(t *testing.T) { + writeFakeCodex(t, + `{"type":"thread.started","thread_id":"fake-thread"}`, + `{"queries":[{"description":"Fetch viewer by id.","document":"query GetViewer($id: ID!) { viewer(id: $id) { id } }","operation_name":"GetViewer","operation_type":"query"}],"unsatisfied":[],"truncated":false}`, + ) + client := newTestClient(t) + + indexResp, err := client.IndexSchema(context.Background(), connect.NewRequest(&yokov1.IndexSchemaRequest{ + Sdl: `type Query { viewer(id: ID!): User } type User { id: ID! 
}`, + })) + require.NoError(t, err) + + resp, err := client.GenerateQuery(context.Background(), connect.NewRequest(&yokov1.GenerateQueryRequest{ + SchemaId: indexResp.Msg.GetSchemaId(), + Prompt: "viewer", + })) + require.NoError(t, err) + + queries := resp.Msg.GetResolution().GetQueries() + require.Len(t, queries, 1) + assert.Equal(t, `{"type":"object","properties":{"id":{"type":"string"}},"required":["id"]}`, queries[0].GetVariablesSchema()) +} + +func TestGenerateQueryForwardsUnsatisfiedAndTruncated(t *testing.T) { writeFakeCodex(t, `{"type":"thread.started","thread_id":"fake-thread"}`, - `{"operations":[{"name":"getViewer","body":"query getViewer { viewer { id } }","kind":"query","description":"Fetches the current viewer."}]}`, + `{"queries":[],"unsatisfied":[{"reason":"no field on the schema carries that filter dimension"}],"truncated":true}`, ) client := newTestClient(t) - indexResp, err := client.Index(context.Background(), connect.NewRequest(&yokov1.IndexRequest{ - SchemaSdl: "type Query { viewer: User } type User { id: ID! 
}", + indexResp, err := client.IndexSchema(context.Background(), connect.NewRequest(&yokov1.IndexSchemaRequest{ + Sdl: testSDL, })) require.NoError(t, err) - searchResp, err := client.Search(context.Background(), connect.NewRequest(&yokov1.SearchRequest{ - SchemaId: indexResp.Msg.GetSchemaId(), - Prompts: []string{"get the viewer"}, - SessionId: "session-1", + resp, err := client.GenerateQuery(context.Background(), connect.NewRequest(&yokov1.GenerateQueryRequest{ + SchemaId: indexResp.Msg.GetSchemaId(), + Prompt: "viewer filtered by some unknown thing", })) require.NoError(t, err) - expected := &yokov1.SearchResponse{ - Operations: []*yokov1.GeneratedOperation{ - { - Name: "getViewer", - Body: "query getViewer { viewer { id } }", - Kind: yokov1.OperationKind_OPERATION_KIND_QUERY, - Description: "Fetches the current viewer.", - }, + expected := &yokov1.GenerateQueryResponse{ + Resolution: &yokov1.Resolution{ + Queries: []*yokov1.ResolvedQuery{}, + Unsatisfied: []*yokov1.Unsatisfied{{Reason: "no field on the schema carries that filter dimension"}}, + Truncated: true, }, } - assert.Equal(t, normalizeSearchResponse(t, expected), normalizeSearchResponse(t, searchResp.Msg)) + assert.Equal(t, normalizeGenerateResponse(t, expected), normalizeGenerateResponse(t, resp.Msg)) } -func TestSearchUnknownSchemaIDReturnsNotFound(t *testing.T) { +func TestGenerateQueryUnknownSchemaIDReturnsNotFound(t *testing.T) { writeFakeCodex(t, `{"type":"thread.started","thread_id":"fake-thread"}`, - `{"operations":[]}`, + `{"queries":[],"unsatisfied":[],"truncated":false}`, ) client := newTestClient(t) - _, err := client.Search(context.Background(), connect.NewRequest(&yokov1.SearchRequest{ + _, err := client.GenerateQuery(context.Background(), connect.NewRequest(&yokov1.GenerateQueryRequest{ SchemaId: "unknown", - Prompts: []string{"get the viewer"}, + Prompt: "get the viewer", })) require.Error(t, err) assert.Equal(t, connect.CodeNotFound, connect.CodeOf(err)) } -func 
TestSearchBadJSONReturnsInternal(t *testing.T) { +func TestGenerateQueryBadJSONReturnsInternal(t *testing.T) { writeFakeCodex(t, `{"type":"thread.started","thread_id":"fake-thread"}`, `not json`, ) client := newTestClient(t) - indexResp, err := client.Index(context.Background(), connect.NewRequest(&yokov1.IndexRequest{ - SchemaSdl: "type Query { viewer: ID! }", + indexResp, err := client.IndexSchema(context.Background(), connect.NewRequest(&yokov1.IndexSchemaRequest{ + Sdl: testSDL, })) require.NoError(t, err) - _, err = client.Search(context.Background(), connect.NewRequest(&yokov1.SearchRequest{ + _, err = client.GenerateQuery(context.Background(), connect.NewRequest(&yokov1.GenerateQueryRequest{ SchemaId: indexResp.Msg.GetSchemaId(), - Prompts: []string{"get the viewer"}, + Prompt: "get the viewer", })) require.Error(t, err) @@ -104,13 +157,13 @@ func newTestClient(t *testing.T) yokov1connect.YokoServiceClient { } // writeFakeCodex installs a stub `codex` binary on PATH that mocks both the -// initial `codex exec` (Index pre-warm) and `codex exec resume` (Search) calls. -// The stub detects "resume" in its argv to switch modes. +// initial `codex exec` (IndexSchema pre-warm) and `codex exec resume` +// (GenerateQuery) calls. The stub detects "resume" in its argv to switch modes. // -// - indexStdout is printed to stdout for the Index call (e.g. a JSONL line -// like {"type":"thread.started","thread_id":"..."}). -// - resumeMessage is written to the file passed via -o for the Search -// call (codex's --output-last-message contract). +// - indexStdout is printed to stdout for the IndexSchema call (e.g. a JSONL +// line like {"type":"thread.started","thread_id":"..."}). +// - resumeMessage is written to the file passed via -o for the +// GenerateQuery call (codex's --output-last-message contract). 
func writeFakeCodex(t *testing.T, indexStdout, resumeMessage string) { t.Helper() @@ -127,7 +180,7 @@ func writeFakeCodex(t *testing.T, indexStdout, resumeMessage string) { path := filepath.Join(dir, name) var script string if runtime.GOOS == "windows" { - // Minimal Windows fallback — only Index path is exercised in CI on Unix. + // Minimal Windows fallback — only IndexSchema path is exercised in CI on Unix. script = "@echo off\r\ntype \"" + indexFile + "\"\r\n" } else { script = "#!/bin/sh\n" + @@ -152,12 +205,12 @@ func writeFakeCodex(t *testing.T, indexStdout, resumeMessage string) { var _ http.Handler = (*http.ServeMux)(nil) -func normalizeSearchResponse(t *testing.T, resp *yokov1.SearchResponse) *yokov1.SearchResponse { +func normalizeGenerateResponse(t *testing.T, resp *yokov1.GenerateQueryResponse) *yokov1.GenerateQueryResponse { t.Helper() data, err := proto.Marshal(resp) require.NoError(t, err) - normalized := &yokov1.SearchResponse{} + normalized := &yokov1.GenerateQueryResponse{} require.NoError(t, proto.Unmarshal(data, normalized)) return normalized } diff --git a/proto/wg/cosmo/code_mode/yoko/v1/yoko.proto b/proto/wg/cosmo/code_mode/yoko/v1/yoko.proto index 3add37193b..5e6592a56a 100644 --- a/proto/wg/cosmo/code_mode/yoko/v1/yoko.proto +++ b/proto/wg/cosmo/code_mode/yoko/v1/yoko.proto @@ -1,84 +1,85 @@ syntax = "proto3"; -package wundergraph.cosmo.code_mode.yoko.v1; - -// Yoko generates GraphQL operations for natural-language prompts. -// -// Two-step flow: -// 1. Index(schema_sdl) -> schema_id. Idempotent: the same SDL -// always returns the same schema_id for as long as the index -// is retained. The router caches schema_id and re-indexes only -// on supergraph reload (or when Search returns NOT_FOUND). -// 2. Search(prompts, schema_id) -> operations. Yoko owns prompt -// fan-out, partial-failure handling, cross-prompt deduplication, -// and ranking. -// -// The router never sends a schema hash. 
Yoko is the sole authority -// on schema identity; the router only sends raw SDL on Index and an -// opaque id on Search. +package yoko.v1; + +import "buf/validate/validate.proto"; + +option go_package = "github.com/wundergraph/yoko/gen/yoko/v1;yokov1"; + +// YokoService turns natural-language prompts into validated GraphQL +// operations against an indexed schema. Clients first index a schema +// (returning a stable schema_id) and then call GenerateQuery with that +// id and a prompt to receive one or more compiled operations. service YokoService { - rpc Index(IndexRequest) returns (IndexResponse); - rpc Search(SearchRequest) returns (SearchResponse); + // IndexSchema parses, enriches, embeds and indexes a GraphQL SDL. + // Returns the deterministic schema_id callers pass to GenerateQuery. + rpc IndexSchema(IndexSchemaRequest) returns (IndexSchemaResponse); + + // GenerateQuery turns a natural-language prompt into one or more + // compiled GraphQL operations against the previously indexed schema. + rpc GenerateQuery(GenerateQueryRequest) returns (GenerateQueryResponse); } -message IndexRequest { - // The supergraph SDL to index. Sent in full on every Index call; - // Yoko deduplicates internally and is free to short-circuit when - // the SDL is already known. - string schema_sdl = 1; +message IndexSchemaRequest { + // GraphQL Schema Definition Language (SDL) for the target API. + // Must contain at least one non-whitespace character. + string sdl = 1 [(buf.validate.field).string.pattern = "\\S"]; } -message IndexResponse { - // Opaque, Yoko-assigned identifier for this schema. Stable for as - // long as Yoko retains the index. Subsequent Search calls pass this - // back instead of the full SDL. Idempotent: the same SDL returns - // the same schema_id. +message IndexSchemaResponse { + // Stable id derived from the indexed SDL; pass to GenerateQuery. string schema_id = 1; } -message SearchRequest { - // Batch of natural-language prompts. Bounded at 20 by the host. 
- repeated string prompts = 1; +message GenerateQueryRequest { + // schema_id from a prior IndexSchema call. + string schema_id = 1 [(buf.validate.field).string.pattern = "\\S"]; - // Identifier returned by a prior Index call. If Yoko no longer - // recognizes the id (e.g. eviction, restart), it MUST return the - // Connect error code NOT_FOUND; the router re-indexes and retries - // the call exactly once. - string schema_id = 2; + // Natural-language description of what the caller wants to fetch. + string prompt = 2 [(buf.validate.field).string.pattern = "\\S"]; +} - // Opaque MCP session ID for telemetry correlation only. - // Yoko MUST NOT use this for stateful behavior — sessions are owned - // by the router. - string session_id = 3; +message GenerateQueryResponse { + Resolution resolution = 1; } -message SearchResponse { - // Operations across all prompts, already deduplicated and ranked. - // Order is significant: earlier entries rank higher and are preferred - // when bundle truncation drops from the tail. - repeated GeneratedOperation operations = 1; +message Resolution { + // One entry per produced query; each is a self-contained operation + // with a natural-language description of what it does. + repeated ResolvedQuery queries = 1; + + // One entry per requirement we could not satisfy; each carries a + // natural-language reason. + repeated Unsatisfied unsatisfied = 2; + + // True when the propose agent ran out of turns before committing + // every requirement; clients may want to retry with a tighter prompt. + bool truncated = 3; } -message GeneratedOperation { - // Suggested operation name (camelCase preferred). The host applies - // its own identifier normalization and in-session collision-suffix - // logic on top of this — see §6. - string name = 1; +message ResolvedQuery { + // One short user-facing sentence describing what this query does. + string description = 1; + + // GraphQL operation document — exactly one named operation. 
+ string document = 2; - // GraphQL operation body (query or mutation source text). - string body = 2; + // Operation name parsed from the document (e.g. "GetUserPosts"). + string operation_name = 3; - // Operation kind. Subscriptions are out of scope; if Yoko returns - // one, the host drops it with a single warn log. - OperationKind kind = 3; + // One of "query", "mutation", "subscription". + string operation_type = 4; - // Human-readable description, surfaced as JSDoc on the typed - // `tools.` signature in the rendered bundle. - string description = 4; + // JSON Schema for the operation's $variables object, derived + // statically from the document. Carried as a JSON-encoded string so + // JSON clients see a readable schema (a `bytes` field would surface + // as base64 over JSON transport). + string variables_schema = 5; } -enum OperationKind { - OPERATION_KIND_UNSPECIFIED = 0; - OPERATION_KIND_QUERY = 1; - OPERATION_KIND_MUTATION = 2; +message Unsatisfied { + // Natural-language explanation of why this requirement could not + // be satisfied (e.g. "no field on the schema carries that filter + // dimension"). + string reason = 1; } diff --git a/router/gen/proto/buf/validate/validate.pb.go b/router/gen/proto/buf/validate/validate.pb.go new file mode 100644 index 0000000000..e62b840677 --- /dev/null +++ b/router/gen/proto/buf/validate/validate.pb.go @@ -0,0 +1,9165 @@ +// Copyright 2023-2026 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.10 +// protoc (unknown) +// source: buf/validate/validate.proto + +// [Protovalidate](https://protovalidate.com/) is the semantic validation library for Protobuf. +// It provides standard annotations to validate common rules on messages and fields, as well as the ability to use [CEL](https://cel.dev) to write custom rules. +// It's the next generation of [protoc-gen-validate](https://github.com/bufbuild/protoc-gen-validate). +// +// This package provides the options, messages, and enums that power Protovalidate. +// Apply its options to messages, fields, and oneofs in your Protobuf schemas to add validation rules: +// +// ```proto +// message User { +// string id = 1 [(buf.validate.field).string.uuid = true]; +// string first_name = 2 [(buf.validate.field).string.max_len = 64]; +// string last_name = 3 [(buf.validate.field).string.max_len = 64]; +// +// option (buf.validate.message).cel = { +// id: "first_name_requires_last_name" +// message: "last_name must be present if first_name is present" +// expression: "!has(this.first_name) || has(this.last_name)" +// }; +// } +// ``` +// +// These rules are enforced at runtime by language-specific libraries. +// See the [developer quickstart](https://protovalidate.com/quickstart/) to get started, or go directly to the runtime library for your language: +// [Go](https://github.com/bufbuild/protovalidate-go), +// [JavaScript/TypeScript](https://github.com/bufbuild/protovalidate-es), +// [Java](https://github.com/bufbuild/protovalidate-java), +// [Python](https://github.com/bufbuild/protovalidate-python), +// or [C++](https://github.com/bufbuild/protovalidate-cc). 
+ +package validate + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + descriptorpb "google.golang.org/protobuf/types/descriptorpb" + durationpb "google.golang.org/protobuf/types/known/durationpb" + fieldmaskpb "google.golang.org/protobuf/types/known/fieldmaskpb" + timestamppb "google.golang.org/protobuf/types/known/timestamppb" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// Specifies how `FieldRules.ignore` behaves, depending on the field's value, and +// whether the field tracks presence. +type Ignore int32 + +const ( + // Ignore rules if the field tracks presence and is unset. This is the default + // behavior. + // + // In proto3, only message fields, members of a Protobuf `oneof`, and fields + // with the `optional` label track presence. 
Consequently, the following fields + // are always validated, whether a value is set or not: + // + // ```proto + // syntax="proto3"; + // + // message RulesApply { + // string email = 1 [ + // (buf.validate.field).string.email = true + // ]; + // int32 age = 2 [ + // (buf.validate.field).int32.gt = 0 + // ]; + // repeated string labels = 3 [ + // (buf.validate.field).repeated.min_items = 1 + // ]; + // } + // + // ``` + // + // In contrast, the following fields track presence, and are only validated if + // a value is set: + // + // ```proto + // syntax="proto3"; + // + // message RulesApplyIfSet { + // optional string email = 1 [ + // (buf.validate.field).string.email = true + // ]; + // oneof ref { + // string reference = 2 [ + // (buf.validate.field).string.uuid = true + // ]; + // string name = 3 [ + // (buf.validate.field).string.min_len = 4 + // ]; + // } + // SomeMessage msg = 4 [ + // (buf.validate.field).cel = {/* ... */} + // ]; + // } + // + // ``` + // + // To ensure that such a field is set, add the `required` rule. + // + // To learn which fields track presence, see the + // [Field Presence cheat sheet](https://protobuf.dev/programming-guides/field_presence/#cheat). + Ignore_IGNORE_UNSPECIFIED Ignore = 0 + // Ignore rules if the field is unset, or set to the zero value. + // + // The zero value depends on the field type: + // - For strings, the zero value is the empty string. + // - For bytes, the zero value is empty bytes. + // - For bool, the zero value is false. + // - For numeric types, the zero value is zero. + // - For enums, the zero value is the first defined enum value. + // - For repeated fields, the zero is an empty list. + // - For map fields, the zero is an empty map. + // - For message fields, absence of the message (typically a null-value) is considered zero value. + // + // For fields that track presence (e.g. adding the `optional` label in proto3), + // this a no-op and behavior is the same as the default `IGNORE_UNSPECIFIED`. 
+ Ignore_IGNORE_IF_ZERO_VALUE Ignore = 1 + // Always ignore rules, including the `required` rule. + // + // This is useful for ignoring the rules of a referenced message, or to + // temporarily ignore rules during development. + // + // ```proto + // + // message MyMessage { + // // The field's rules will always be ignored, including any validations + // // on value's fields. + // MyOtherMessage value = 1 [ + // (buf.validate.field).ignore = IGNORE_ALWAYS + // ]; + // } + // + // ``` + Ignore_IGNORE_ALWAYS Ignore = 3 +) + +// Enum value maps for Ignore. +var ( + Ignore_name = map[int32]string{ + 0: "IGNORE_UNSPECIFIED", + 1: "IGNORE_IF_ZERO_VALUE", + 3: "IGNORE_ALWAYS", + } + Ignore_value = map[string]int32{ + "IGNORE_UNSPECIFIED": 0, + "IGNORE_IF_ZERO_VALUE": 1, + "IGNORE_ALWAYS": 3, + } +) + +func (x Ignore) Enum() *Ignore { + p := new(Ignore) + *p = x + return p +} + +func (x Ignore) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (Ignore) Descriptor() protoreflect.EnumDescriptor { + return file_buf_validate_validate_proto_enumTypes[0].Descriptor() +} + +func (Ignore) Type() protoreflect.EnumType { + return &file_buf_validate_validate_proto_enumTypes[0] +} + +func (x Ignore) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Do not use. +func (x *Ignore) UnmarshalJSON(b []byte) error { + num, err := protoimpl.X.UnmarshalJSONEnum(x.Descriptor(), b) + if err != nil { + return err + } + *x = Ignore(num) + return nil +} + +// Deprecated: Use Ignore.Descriptor instead. +func (Ignore) EnumDescriptor() ([]byte, []int) { + return file_buf_validate_validate_proto_rawDescGZIP(), []int{0} +} + +// KnownRegex contains some well-known patterns. +type KnownRegex int32 + +const ( + KnownRegex_KNOWN_REGEX_UNSPECIFIED KnownRegex = 0 + // HTTP header name as defined by [RFC 7230](https://datatracker.ietf.org/doc/html/rfc7230#section-3.2). 
+ KnownRegex_KNOWN_REGEX_HTTP_HEADER_NAME KnownRegex = 1 + // HTTP header value as defined by [RFC 7230](https://datatracker.ietf.org/doc/html/rfc7230#section-3.2.4). + KnownRegex_KNOWN_REGEX_HTTP_HEADER_VALUE KnownRegex = 2 +) + +// Enum value maps for KnownRegex. +var ( + KnownRegex_name = map[int32]string{ + 0: "KNOWN_REGEX_UNSPECIFIED", + 1: "KNOWN_REGEX_HTTP_HEADER_NAME", + 2: "KNOWN_REGEX_HTTP_HEADER_VALUE", + } + KnownRegex_value = map[string]int32{ + "KNOWN_REGEX_UNSPECIFIED": 0, + "KNOWN_REGEX_HTTP_HEADER_NAME": 1, + "KNOWN_REGEX_HTTP_HEADER_VALUE": 2, + } +) + +func (x KnownRegex) Enum() *KnownRegex { + p := new(KnownRegex) + *p = x + return p +} + +func (x KnownRegex) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (KnownRegex) Descriptor() protoreflect.EnumDescriptor { + return file_buf_validate_validate_proto_enumTypes[1].Descriptor() +} + +func (KnownRegex) Type() protoreflect.EnumType { + return &file_buf_validate_validate_proto_enumTypes[1] +} + +func (x KnownRegex) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Do not use. +func (x *KnownRegex) UnmarshalJSON(b []byte) error { + num, err := protoimpl.X.UnmarshalJSONEnum(x.Descriptor(), b) + if err != nil { + return err + } + *x = KnownRegex(num) + return nil +} + +// Deprecated: Use KnownRegex.Descriptor instead. +func (KnownRegex) EnumDescriptor() ([]byte, []int) { + return file_buf_validate_validate_proto_rawDescGZIP(), []int{1} +} + +// `Rule` represents a validation rule written in the Common Expression +// Language (CEL) syntax. Each Rule includes a unique identifier, an +// optional error message, and the CEL expression to evaluate. For more +// information, [see our documentation](https://buf.build/docs/protovalidate/schemas/custom-rules/). 
+// +// ```proto +// +// message Foo { +// option (buf.validate.message).cel = { +// id: "foo.bar" +// message: "bar must be greater than 0" +// expression: "this.bar > 0" +// }; +// int32 bar = 1; +// } +// +// ``` +type Rule struct { + state protoimpl.MessageState `protogen:"open.v1"` + // `id` is a string that serves as a machine-readable name for this Rule. + // It should be unique within its scope, which could be either a message or a field. + Id *string `protobuf:"bytes,1,opt,name=id" json:"id,omitempty"` + // `message` is an optional field that provides a human-readable error message + // for this Rule when the CEL expression evaluates to false. If a + // non-empty message is provided, any strings resulting from the CEL + // expression evaluation are ignored. + Message *string `protobuf:"bytes,2,opt,name=message" json:"message,omitempty"` + // `expression` is the actual CEL expression that will be evaluated for + // validation. This string must resolve to either a boolean or a string + // value. If the expression evaluates to false or a non-empty string, the + // validation is considered failed, and the message is rejected. + Expression *string `protobuf:"bytes,3,opt,name=expression" json:"expression,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Rule) Reset() { + *x = Rule{} + mi := &file_buf_validate_validate_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Rule) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Rule) ProtoMessage() {} + +func (x *Rule) ProtoReflect() protoreflect.Message { + mi := &file_buf_validate_validate_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Rule.ProtoReflect.Descriptor instead. 
+func (*Rule) Descriptor() ([]byte, []int) { + return file_buf_validate_validate_proto_rawDescGZIP(), []int{0} +} + +func (x *Rule) GetId() string { + if x != nil && x.Id != nil { + return *x.Id + } + return "" +} + +func (x *Rule) GetMessage() string { + if x != nil && x.Message != nil { + return *x.Message + } + return "" +} + +func (x *Rule) GetExpression() string { + if x != nil && x.Expression != nil { + return *x.Expression + } + return "" +} + +// MessageRules represents validation rules that are applied to the entire message. +// It includes disabling options and a list of Rule messages representing Common Expression Language (CEL) validation rules. +type MessageRules struct { + state protoimpl.MessageState `protogen:"open.v1"` + // `cel_expression` is a repeated field CEL expressions. Each expression specifies a validation + // rule to be applied to this message. These rules are written in Common Expression Language (CEL) syntax. + // + // This is a simplified form of the `cel` Rule field, where only `expression` is set. This allows for + // simpler syntax when defining CEL Rules where `id` and `message` derived from the `expression`. `id` will + // be same as the `expression`. + // + // For more information, [see our documentation](https://buf.build/docs/protovalidate/schemas/custom-rules/). + // + // ```proto + // + // message MyMessage { + // // The field `foo` must be greater than 42. + // option (buf.validate.message).cel_expression = "this.foo > 42"; + // // The field `foo` must be less than 84. + // option (buf.validate.message).cel_expression = "this.foo < 84"; + // optional int32 foo = 1; + // } + // + // ``` + CelExpression []string `protobuf:"bytes,5,rep,name=cel_expression,json=celExpression" json:"cel_expression,omitempty"` + // `cel` is a repeated field of type Rule. Each Rule specifies a validation rule to be applied to this message. + // These rules are written in Common Expression Language (CEL) syntax. 
For more information, + // [see our documentation](https://buf.build/docs/protovalidate/schemas/custom-rules/). + // + // ```proto + // + // message MyMessage { + // // The field `foo` must be greater than 42. + // option (buf.validate.message).cel = { + // id: "my_message.value", + // message: "must be greater than 42", + // expression: "this.foo > 42", + // }; + // optional int32 foo = 1; + // } + // + // ``` + Cel []*Rule `protobuf:"bytes,3,rep,name=cel" json:"cel,omitempty"` + // `oneof` is a repeated field of type MessageOneofRule that specifies a list of fields + // of which at most one can be present. If `required` is also specified, then exactly one + // of the specified fields _must_ be present. + // + // This will enforce oneof-like constraints with a few features not provided by + // actual Protobuf oneof declarations: + // 1. Repeated and map fields are allowed in this validation. In a Protobuf oneof, + // only scalar fields are allowed. + // 2. Fields with implicit presence are allowed. In a Protobuf oneof, all member + // fields have explicit presence. This means that, for the purpose of determining + // how many fields are set, explicitly setting such a field to its zero value is + // effectively the same as not setting it at all. + // 3. This will always generate validation errors for a message unmarshalled from + // serialized data that sets more than one field. With a Protobuf oneof, when + // multiple fields are present in the serialized form, earlier values are usually + // silently ignored when unmarshalling, with only the last field being set when + // unmarshalling completes. + // + // Note that adding a field to a `oneof` will also set the IGNORE_IF_ZERO_VALUE on the fields. This means + // only the field that is set will be validated and the unset fields are not validated according to the field rules. + // This behavior can be overridden by setting `ignore` against a field. 
+ // + // ```proto + // + // message MyMessage { + // // Only one of `field1` or `field2` _can_ be present in this message. + // option (buf.validate.message).oneof = { fields: ["field1", "field2"] }; + // // Exactly one of `field3` or `field4` _must_ be present in this message. + // option (buf.validate.message).oneof = { fields: ["field3", "field4"], required: true }; + // string field1 = 1; + // bytes field2 = 2; + // bool field3 = 3; + // int32 field4 = 4; + // } + // + // ``` + Oneof []*MessageOneofRule `protobuf:"bytes,4,rep,name=oneof" json:"oneof,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *MessageRules) Reset() { + *x = MessageRules{} + mi := &file_buf_validate_validate_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *MessageRules) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*MessageRules) ProtoMessage() {} + +func (x *MessageRules) ProtoReflect() protoreflect.Message { + mi := &file_buf_validate_validate_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use MessageRules.ProtoReflect.Descriptor instead. +func (*MessageRules) Descriptor() ([]byte, []int) { + return file_buf_validate_validate_proto_rawDescGZIP(), []int{1} +} + +func (x *MessageRules) GetCelExpression() []string { + if x != nil { + return x.CelExpression + } + return nil +} + +func (x *MessageRules) GetCel() []*Rule { + if x != nil { + return x.Cel + } + return nil +} + +func (x *MessageRules) GetOneof() []*MessageOneofRule { + if x != nil { + return x.Oneof + } + return nil +} + +type MessageOneofRule struct { + state protoimpl.MessageState `protogen:"open.v1"` + // A list of field names to include in the oneof. All field names must be + // defined in the message. 
At least one field must be specified, and + // duplicates are not permitted. + Fields []string `protobuf:"bytes,1,rep,name=fields" json:"fields,omitempty"` + // If true, one of the fields specified _must_ be set. + Required *bool `protobuf:"varint,2,opt,name=required" json:"required,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *MessageOneofRule) Reset() { + *x = MessageOneofRule{} + mi := &file_buf_validate_validate_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *MessageOneofRule) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*MessageOneofRule) ProtoMessage() {} + +func (x *MessageOneofRule) ProtoReflect() protoreflect.Message { + mi := &file_buf_validate_validate_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use MessageOneofRule.ProtoReflect.Descriptor instead. +func (*MessageOneofRule) Descriptor() ([]byte, []int) { + return file_buf_validate_validate_proto_rawDescGZIP(), []int{2} +} + +func (x *MessageOneofRule) GetFields() []string { + if x != nil { + return x.Fields + } + return nil +} + +func (x *MessageOneofRule) GetRequired() bool { + if x != nil && x.Required != nil { + return *x.Required + } + return false +} + +// The `OneofRules` message type enables you to manage rules for +// oneof fields in your protobuf messages. +type OneofRules struct { + state protoimpl.MessageState `protogen:"open.v1"` + // If `required` is true, exactly one field of the oneof must be set. A + // validation error is returned if no fields in the oneof are set. Further rules + // should be placed on the fields themselves to ensure they are valid values, + // such as `min_len` or `gt`. 
+ // + // ```proto + // + // message MyMessage { + // oneof value { + // // Either `a` or `b` must be set. If `a` is set, it must also be + // // non-empty; whereas if `b` is set, it can still be an empty string. + // option (buf.validate.oneof).required = true; + // string a = 1 [(buf.validate.field).string.min_len = 1]; + // string b = 2; + // } + // } + // + // ``` + Required *bool `protobuf:"varint,1,opt,name=required" json:"required,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *OneofRules) Reset() { + *x = OneofRules{} + mi := &file_buf_validate_validate_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *OneofRules) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*OneofRules) ProtoMessage() {} + +func (x *OneofRules) ProtoReflect() protoreflect.Message { + mi := &file_buf_validate_validate_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use OneofRules.ProtoReflect.Descriptor instead. +func (*OneofRules) Descriptor() ([]byte, []int) { + return file_buf_validate_validate_proto_rawDescGZIP(), []int{3} +} + +func (x *OneofRules) GetRequired() bool { + if x != nil && x.Required != nil { + return *x.Required + } + return false +} + +// FieldRules encapsulates the rules for each type of field. Depending on +// the field, the correct set should be used to ensure proper validations. +type FieldRules struct { + state protoimpl.MessageState `protogen:"open.v1"` + // `cel_expression` is a repeated field CEL expressions. Each expression specifies a validation + // rule to be applied to this message. These rules are written in Common Expression Language (CEL) syntax. + // + // This is a simplified form of the `cel` Rule field, where only `expression` is set. 
This allows for + // simpler syntax when defining CEL Rules where `id` and `message` derived from the `expression`. `id` will + // be same as the `expression`. + // + // For more information, [see our documentation](https://buf.build/docs/protovalidate/schemas/custom-rules/). + // + // ```proto + // + // message MyMessage { + // // The field `value` must be greater than 42. + // optional int32 value = 1 [(buf.validate.field).cel_expression = "this > 42"]; + // } + // + // ``` + CelExpression []string `protobuf:"bytes,29,rep,name=cel_expression,json=celExpression" json:"cel_expression,omitempty"` + // `cel` is a repeated field used to represent a textual expression + // in the Common Expression Language (CEL) syntax. For more information, + // [see our documentation](https://buf.build/docs/protovalidate/schemas/custom-rules/). + // + // ```proto + // + // message MyMessage { + // // The field `value` must be greater than 42. + // optional int32 value = 1 [(buf.validate.field).cel = { + // id: "my_message.value", + // message: "must be greater than 42", + // expression: "this > 42", + // }]; + // } + // + // ``` + Cel []*Rule `protobuf:"bytes,23,rep,name=cel" json:"cel,omitempty"` + // If `required` is true, the field must be set. A validation error is returned + // if the field is not set. + // + // ```proto + // syntax="proto3"; + // + // message FieldsWithPresence { + // // Requires any string to be set, including the empty string. + // optional string link = 1 [ + // (buf.validate.field).required = true + // ]; + // // Requires true or false to be set. + // optional bool disabled = 2 [ + // (buf.validate.field).required = true + // ]; + // // Requires a message to be set, including the empty message. + // SomeMessage msg = 4 [ + // (buf.validate.field).required = true + // ]; + // } + // + // ``` + // + // All fields in the example above track presence. By default, Protovalidate + // ignores rules on those fields if no value is set. 
`required` ensures that + // the fields are set and valid. + // + // Fields that don't track presence are always validated by Protovalidate, + // whether they are set or not. It is not necessary to add `required`. It + // can be added to indicate that the field cannot be the zero value. + // + // ```proto + // syntax="proto3"; + // + // message FieldsWithoutPresence { + // // `string.email` always applies, even to an empty string. + // string link = 1 [ + // (buf.validate.field).string.email = true + // ]; + // // `repeated.min_items` always applies, even to an empty list. + // repeated string labels = 2 [ + // (buf.validate.field).repeated.min_items = 1 + // ]; + // // `required`, for fields that don't track presence, indicates + // // the value of the field can't be the zero value. + // int32 zero_value_not_allowed = 3 [ + // (buf.validate.field).required = true + // ]; + // } + // + // ``` + // + // To learn which fields track presence, see the + // [Field Presence cheat sheet](https://protobuf.dev/programming-guides/field_presence/#cheat). + // + // Note: While field rules can be applied to repeated items, map keys, and map + // values, the elements are always considered to be set. Consequently, + // specifying `repeated.items.required` is redundant. + Required *bool `protobuf:"varint,25,opt,name=required" json:"required,omitempty"` + // Ignore validation rules on the field if its value matches the specified + // criteria. See the `Ignore` enum for details. + // + // ```proto + // + // message UpdateRequest { + // // The uri rule only applies if the field is not an empty string. 
+ // string url = 1 [ + // (buf.validate.field).ignore = IGNORE_IF_ZERO_VALUE, + // (buf.validate.field).string.uri = true + // ]; + // } + // + // ``` + Ignore *Ignore `protobuf:"varint,27,opt,name=ignore,enum=buf.validate.Ignore" json:"ignore,omitempty"` + // Types that are valid to be assigned to Type: + // + // *FieldRules_Float + // *FieldRules_Double + // *FieldRules_Int32 + // *FieldRules_Int64 + // *FieldRules_Uint32 + // *FieldRules_Uint64 + // *FieldRules_Sint32 + // *FieldRules_Sint64 + // *FieldRules_Fixed32 + // *FieldRules_Fixed64 + // *FieldRules_Sfixed32 + // *FieldRules_Sfixed64 + // *FieldRules_Bool + // *FieldRules_String_ + // *FieldRules_Bytes + // *FieldRules_Enum + // *FieldRules_Repeated + // *FieldRules_Map + // *FieldRules_Any + // *FieldRules_Duration + // *FieldRules_FieldMask + // *FieldRules_Timestamp + Type isFieldRules_Type `protobuf_oneof:"type"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *FieldRules) Reset() { + *x = FieldRules{} + mi := &file_buf_validate_validate_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *FieldRules) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*FieldRules) ProtoMessage() {} + +func (x *FieldRules) ProtoReflect() protoreflect.Message { + mi := &file_buf_validate_validate_proto_msgTypes[4] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use FieldRules.ProtoReflect.Descriptor instead. 
+func (*FieldRules) Descriptor() ([]byte, []int) { + return file_buf_validate_validate_proto_rawDescGZIP(), []int{4} +} + +func (x *FieldRules) GetCelExpression() []string { + if x != nil { + return x.CelExpression + } + return nil +} + +func (x *FieldRules) GetCel() []*Rule { + if x != nil { + return x.Cel + } + return nil +} + +func (x *FieldRules) GetRequired() bool { + if x != nil && x.Required != nil { + return *x.Required + } + return false +} + +func (x *FieldRules) GetIgnore() Ignore { + if x != nil && x.Ignore != nil { + return *x.Ignore + } + return Ignore_IGNORE_UNSPECIFIED +} + +func (x *FieldRules) GetType() isFieldRules_Type { + if x != nil { + return x.Type + } + return nil +} + +func (x *FieldRules) GetFloat() *FloatRules { + if x != nil { + if x, ok := x.Type.(*FieldRules_Float); ok { + return x.Float + } + } + return nil +} + +func (x *FieldRules) GetDouble() *DoubleRules { + if x != nil { + if x, ok := x.Type.(*FieldRules_Double); ok { + return x.Double + } + } + return nil +} + +func (x *FieldRules) GetInt32() *Int32Rules { + if x != nil { + if x, ok := x.Type.(*FieldRules_Int32); ok { + return x.Int32 + } + } + return nil +} + +func (x *FieldRules) GetInt64() *Int64Rules { + if x != nil { + if x, ok := x.Type.(*FieldRules_Int64); ok { + return x.Int64 + } + } + return nil +} + +func (x *FieldRules) GetUint32() *UInt32Rules { + if x != nil { + if x, ok := x.Type.(*FieldRules_Uint32); ok { + return x.Uint32 + } + } + return nil +} + +func (x *FieldRules) GetUint64() *UInt64Rules { + if x != nil { + if x, ok := x.Type.(*FieldRules_Uint64); ok { + return x.Uint64 + } + } + return nil +} + +func (x *FieldRules) GetSint32() *SInt32Rules { + if x != nil { + if x, ok := x.Type.(*FieldRules_Sint32); ok { + return x.Sint32 + } + } + return nil +} + +func (x *FieldRules) GetSint64() *SInt64Rules { + if x != nil { + if x, ok := x.Type.(*FieldRules_Sint64); ok { + return x.Sint64 + } + } + return nil +} + +func (x *FieldRules) GetFixed32() *Fixed32Rules { + 
if x != nil { + if x, ok := x.Type.(*FieldRules_Fixed32); ok { + return x.Fixed32 + } + } + return nil +} + +func (x *FieldRules) GetFixed64() *Fixed64Rules { + if x != nil { + if x, ok := x.Type.(*FieldRules_Fixed64); ok { + return x.Fixed64 + } + } + return nil +} + +func (x *FieldRules) GetSfixed32() *SFixed32Rules { + if x != nil { + if x, ok := x.Type.(*FieldRules_Sfixed32); ok { + return x.Sfixed32 + } + } + return nil +} + +func (x *FieldRules) GetSfixed64() *SFixed64Rules { + if x != nil { + if x, ok := x.Type.(*FieldRules_Sfixed64); ok { + return x.Sfixed64 + } + } + return nil +} + +func (x *FieldRules) GetBool() *BoolRules { + if x != nil { + if x, ok := x.Type.(*FieldRules_Bool); ok { + return x.Bool + } + } + return nil +} + +func (x *FieldRules) GetString_() *StringRules { + if x != nil { + if x, ok := x.Type.(*FieldRules_String_); ok { + return x.String_ + } + } + return nil +} + +func (x *FieldRules) GetBytes() *BytesRules { + if x != nil { + if x, ok := x.Type.(*FieldRules_Bytes); ok { + return x.Bytes + } + } + return nil +} + +func (x *FieldRules) GetEnum() *EnumRules { + if x != nil { + if x, ok := x.Type.(*FieldRules_Enum); ok { + return x.Enum + } + } + return nil +} + +func (x *FieldRules) GetRepeated() *RepeatedRules { + if x != nil { + if x, ok := x.Type.(*FieldRules_Repeated); ok { + return x.Repeated + } + } + return nil +} + +func (x *FieldRules) GetMap() *MapRules { + if x != nil { + if x, ok := x.Type.(*FieldRules_Map); ok { + return x.Map + } + } + return nil +} + +func (x *FieldRules) GetAny() *AnyRules { + if x != nil { + if x, ok := x.Type.(*FieldRules_Any); ok { + return x.Any + } + } + return nil +} + +func (x *FieldRules) GetDuration() *DurationRules { + if x != nil { + if x, ok := x.Type.(*FieldRules_Duration); ok { + return x.Duration + } + } + return nil +} + +func (x *FieldRules) GetFieldMask() *FieldMaskRules { + if x != nil { + if x, ok := x.Type.(*FieldRules_FieldMask); ok { + return x.FieldMask + } + } + return nil +} + 
+func (x *FieldRules) GetTimestamp() *TimestampRules { + if x != nil { + if x, ok := x.Type.(*FieldRules_Timestamp); ok { + return x.Timestamp + } + } + return nil +} + +type isFieldRules_Type interface { + isFieldRules_Type() +} + +type FieldRules_Float struct { + // Scalar Field Types + Float *FloatRules `protobuf:"bytes,1,opt,name=float,oneof"` +} + +type FieldRules_Double struct { + Double *DoubleRules `protobuf:"bytes,2,opt,name=double,oneof"` +} + +type FieldRules_Int32 struct { + Int32 *Int32Rules `protobuf:"bytes,3,opt,name=int32,oneof"` +} + +type FieldRules_Int64 struct { + Int64 *Int64Rules `protobuf:"bytes,4,opt,name=int64,oneof"` +} + +type FieldRules_Uint32 struct { + Uint32 *UInt32Rules `protobuf:"bytes,5,opt,name=uint32,oneof"` +} + +type FieldRules_Uint64 struct { + Uint64 *UInt64Rules `protobuf:"bytes,6,opt,name=uint64,oneof"` +} + +type FieldRules_Sint32 struct { + Sint32 *SInt32Rules `protobuf:"bytes,7,opt,name=sint32,oneof"` +} + +type FieldRules_Sint64 struct { + Sint64 *SInt64Rules `protobuf:"bytes,8,opt,name=sint64,oneof"` +} + +type FieldRules_Fixed32 struct { + Fixed32 *Fixed32Rules `protobuf:"bytes,9,opt,name=fixed32,oneof"` +} + +type FieldRules_Fixed64 struct { + Fixed64 *Fixed64Rules `protobuf:"bytes,10,opt,name=fixed64,oneof"` +} + +type FieldRules_Sfixed32 struct { + Sfixed32 *SFixed32Rules `protobuf:"bytes,11,opt,name=sfixed32,oneof"` +} + +type FieldRules_Sfixed64 struct { + Sfixed64 *SFixed64Rules `protobuf:"bytes,12,opt,name=sfixed64,oneof"` +} + +type FieldRules_Bool struct { + Bool *BoolRules `protobuf:"bytes,13,opt,name=bool,oneof"` +} + +type FieldRules_String_ struct { + String_ *StringRules `protobuf:"bytes,14,opt,name=string,oneof"` +} + +type FieldRules_Bytes struct { + Bytes *BytesRules `protobuf:"bytes,15,opt,name=bytes,oneof"` +} + +type FieldRules_Enum struct { + // Complex Field Types + Enum *EnumRules `protobuf:"bytes,16,opt,name=enum,oneof"` +} + +type FieldRules_Repeated struct { + Repeated *RepeatedRules 
`protobuf:"bytes,18,opt,name=repeated,oneof"` +} + +type FieldRules_Map struct { + Map *MapRules `protobuf:"bytes,19,opt,name=map,oneof"` +} + +type FieldRules_Any struct { + // Well-Known Field Types + Any *AnyRules `protobuf:"bytes,20,opt,name=any,oneof"` +} + +type FieldRules_Duration struct { + Duration *DurationRules `protobuf:"bytes,21,opt,name=duration,oneof"` +} + +type FieldRules_FieldMask struct { + FieldMask *FieldMaskRules `protobuf:"bytes,28,opt,name=field_mask,json=fieldMask,oneof"` +} + +type FieldRules_Timestamp struct { + Timestamp *TimestampRules `protobuf:"bytes,22,opt,name=timestamp,oneof"` +} + +func (*FieldRules_Float) isFieldRules_Type() {} + +func (*FieldRules_Double) isFieldRules_Type() {} + +func (*FieldRules_Int32) isFieldRules_Type() {} + +func (*FieldRules_Int64) isFieldRules_Type() {} + +func (*FieldRules_Uint32) isFieldRules_Type() {} + +func (*FieldRules_Uint64) isFieldRules_Type() {} + +func (*FieldRules_Sint32) isFieldRules_Type() {} + +func (*FieldRules_Sint64) isFieldRules_Type() {} + +func (*FieldRules_Fixed32) isFieldRules_Type() {} + +func (*FieldRules_Fixed64) isFieldRules_Type() {} + +func (*FieldRules_Sfixed32) isFieldRules_Type() {} + +func (*FieldRules_Sfixed64) isFieldRules_Type() {} + +func (*FieldRules_Bool) isFieldRules_Type() {} + +func (*FieldRules_String_) isFieldRules_Type() {} + +func (*FieldRules_Bytes) isFieldRules_Type() {} + +func (*FieldRules_Enum) isFieldRules_Type() {} + +func (*FieldRules_Repeated) isFieldRules_Type() {} + +func (*FieldRules_Map) isFieldRules_Type() {} + +func (*FieldRules_Any) isFieldRules_Type() {} + +func (*FieldRules_Duration) isFieldRules_Type() {} + +func (*FieldRules_FieldMask) isFieldRules_Type() {} + +func (*FieldRules_Timestamp) isFieldRules_Type() {} + +// PredefinedRules are custom rules that can be re-used with +// multiple fields. 
+type PredefinedRules struct { + state protoimpl.MessageState `protogen:"open.v1"` + // `cel` is a repeated field used to represent a textual expression + // in the Common Expression Language (CEL) syntax. For more information, + // [see our documentation](https://buf.build/docs/protovalidate/schemas/predefined-rules/). + // + // ```proto + // + // message MyMessage { + // // The field `value` must be greater than 42. + // optional int32 value = 1 [(buf.validate.predefined).cel = { + // id: "my_message.value", + // message: "must be greater than 42", + // expression: "this > 42", + // }]; + // } + // + // ``` + Cel []*Rule `protobuf:"bytes,1,rep,name=cel" json:"cel,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *PredefinedRules) Reset() { + *x = PredefinedRules{} + mi := &file_buf_validate_validate_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *PredefinedRules) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*PredefinedRules) ProtoMessage() {} + +func (x *PredefinedRules) ProtoReflect() protoreflect.Message { + mi := &file_buf_validate_validate_proto_msgTypes[5] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use PredefinedRules.ProtoReflect.Descriptor instead. +func (*PredefinedRules) Descriptor() ([]byte, []int) { + return file_buf_validate_validate_proto_rawDescGZIP(), []int{5} +} + +func (x *PredefinedRules) GetCel() []*Rule { + if x != nil { + return x.Cel + } + return nil +} + +// FloatRules describes the rules applied to `float` values. These +// rules may also be applied to the `google.protobuf.FloatValue` Well-Known-Type. 
+type FloatRules struct { + state protoimpl.MessageState `protogen:"open.v1"` + // `const` requires the field value to exactly match the specified value. If + // the field value doesn't match, an error message is generated. + // + // ```proto + // + // message MyFloat { + // // value must equal 42.0 + // float value = 1 [(buf.validate.field).float.const = 42.0]; + // } + // + // ``` + Const *float32 `protobuf:"fixed32,1,opt,name=const" json:"const,omitempty"` + // Types that are valid to be assigned to LessThan: + // + // *FloatRules_Lt + // *FloatRules_Lte + LessThan isFloatRules_LessThan `protobuf_oneof:"less_than"` + // Types that are valid to be assigned to GreaterThan: + // + // *FloatRules_Gt + // *FloatRules_Gte + GreaterThan isFloatRules_GreaterThan `protobuf_oneof:"greater_than"` + // `in` requires the field value to be equal to one of the specified values. + // If the field value isn't one of the specified values, an error message + // is generated. + // + // ```proto + // + // message MyFloat { + // // must be in list [1.0, 2.0, 3.0] + // float value = 1 [(buf.validate.field).float = { in: [1.0, 2.0, 3.0] }]; + // } + // + // ``` + In []float32 `protobuf:"fixed32,6,rep,name=in" json:"in,omitempty"` + // `not_in` requires the field value to not be equal to any of the specified + // values. If the field value is one of the specified values, an error + // message is generated. + // + // ```proto + // + // message MyFloat { + // // value must not be in list [1.0, 2.0, 3.0] + // float value = 1 [(buf.validate.field).float = { not_in: [1.0, 2.0, 3.0] }]; + // } + // + // ``` + NotIn []float32 `protobuf:"fixed32,7,rep,name=not_in,json=notIn" json:"not_in,omitempty"` + // `finite` requires the field value to be finite. If the field value is + // infinite or NaN, an error message is generated. + Finite *bool `protobuf:"varint,8,opt,name=finite" json:"finite,omitempty"` + // `example` specifies values that the field may have. 
These values SHOULD + // conform to other rules. `example` values will not impact validation + // but may be used as helpful guidance on how to populate the given field. + // + // ```proto + // + // message MyFloat { + // float value = 1 [ + // (buf.validate.field).float.example = 1.0, + // (buf.validate.field).float.example = inf + // ]; + // } + // + // ``` + Example []float32 `protobuf:"fixed32,9,rep,name=example" json:"example,omitempty"` + extensionFields protoimpl.ExtensionFields + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *FloatRules) Reset() { + *x = FloatRules{} + mi := &file_buf_validate_validate_proto_msgTypes[6] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *FloatRules) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*FloatRules) ProtoMessage() {} + +func (x *FloatRules) ProtoReflect() protoreflect.Message { + mi := &file_buf_validate_validate_proto_msgTypes[6] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use FloatRules.ProtoReflect.Descriptor instead. 
+func (*FloatRules) Descriptor() ([]byte, []int) { + return file_buf_validate_validate_proto_rawDescGZIP(), []int{6} +} + +func (x *FloatRules) GetConst() float32 { + if x != nil && x.Const != nil { + return *x.Const + } + return 0 +} + +func (x *FloatRules) GetLessThan() isFloatRules_LessThan { + if x != nil { + return x.LessThan + } + return nil +} + +func (x *FloatRules) GetLt() float32 { + if x != nil { + if x, ok := x.LessThan.(*FloatRules_Lt); ok { + return x.Lt + } + } + return 0 +} + +func (x *FloatRules) GetLte() float32 { + if x != nil { + if x, ok := x.LessThan.(*FloatRules_Lte); ok { + return x.Lte + } + } + return 0 +} + +func (x *FloatRules) GetGreaterThan() isFloatRules_GreaterThan { + if x != nil { + return x.GreaterThan + } + return nil +} + +func (x *FloatRules) GetGt() float32 { + if x != nil { + if x, ok := x.GreaterThan.(*FloatRules_Gt); ok { + return x.Gt + } + } + return 0 +} + +func (x *FloatRules) GetGte() float32 { + if x != nil { + if x, ok := x.GreaterThan.(*FloatRules_Gte); ok { + return x.Gte + } + } + return 0 +} + +func (x *FloatRules) GetIn() []float32 { + if x != nil { + return x.In + } + return nil +} + +func (x *FloatRules) GetNotIn() []float32 { + if x != nil { + return x.NotIn + } + return nil +} + +func (x *FloatRules) GetFinite() bool { + if x != nil && x.Finite != nil { + return *x.Finite + } + return false +} + +func (x *FloatRules) GetExample() []float32 { + if x != nil { + return x.Example + } + return nil +} + +type isFloatRules_LessThan interface { + isFloatRules_LessThan() +} + +type FloatRules_Lt struct { + // `lt` requires the field value to be less than the specified value (field < + // value). If the field value is equal to or greater than the specified value, + // an error message is generated. 
+ // + // ```proto + // + // message MyFloat { + // // must be less than 10.0 + // float value = 1 [(buf.validate.field).float.lt = 10.0]; + // } + // + // ``` + Lt float32 `protobuf:"fixed32,2,opt,name=lt,oneof"` +} + +type FloatRules_Lte struct { + // `lte` requires the field value to be less than or equal to the specified + // value (field <= value). If the field value is greater than the specified + // value, an error message is generated. + // + // ```proto + // + // message MyFloat { + // // must be less than or equal to 10.0 + // float value = 1 [(buf.validate.field).float.lte = 10.0]; + // } + // + // ``` + Lte float32 `protobuf:"fixed32,3,opt,name=lte,oneof"` +} + +func (*FloatRules_Lt) isFloatRules_LessThan() {} + +func (*FloatRules_Lte) isFloatRules_LessThan() {} + +type isFloatRules_GreaterThan interface { + isFloatRules_GreaterThan() +} + +type FloatRules_Gt struct { + // `gt` requires the field value to be greater than the specified value + // (exclusive). If the value of `gt` is larger than a specified `lt` or + // `lte`, the range is reversed, and the field value must be outside the + // specified range. If the field value doesn't meet the required conditions, + // an error message is generated. + // + // ```proto + // + // message MyFloat { + // // must be greater than 5.0 [float.gt] + // float value = 1 [(buf.validate.field).float.gt = 5.0]; + // + // // must be greater than 5 and less than 10.0 [float.gt_lt] + // float other_value = 2 [(buf.validate.field).float = { gt: 5.0, lt: 10.0 }]; + // + // // must be greater than 10 or less than 5.0 [float.gt_lt_exclusive] + // float another_value = 3 [(buf.validate.field).float = { gt: 10.0, lt: 5.0 }]; + // } + // + // ``` + Gt float32 `protobuf:"fixed32,4,opt,name=gt,oneof"` +} + +type FloatRules_Gte struct { + // `gte` requires the field value to be greater than or equal to the specified + // value (exclusive). 
If the value of `gte` is larger than a specified `lt` + // or `lte`, the range is reversed, and the field value must be outside the + // specified range. If the field value doesn't meet the required conditions, + // an error message is generated. + // + // ```proto + // + // message MyFloat { + // // must be greater than or equal to 5.0 [float.gte] + // float value = 1 [(buf.validate.field).float.gte = 5.0]; + // + // // must be greater than or equal to 5.0 and less than 10.0 [float.gte_lt] + // float other_value = 2 [(buf.validate.field).float = { gte: 5.0, lt: 10.0 }]; + // + // // must be greater than or equal to 10.0 or less than 5.0 [float.gte_lt_exclusive] + // float another_value = 3 [(buf.validate.field).float = { gte: 10.0, lt: 5.0 }]; + // } + // + // ``` + Gte float32 `protobuf:"fixed32,5,opt,name=gte,oneof"` +} + +func (*FloatRules_Gt) isFloatRules_GreaterThan() {} + +func (*FloatRules_Gte) isFloatRules_GreaterThan() {} + +// DoubleRules describes the rules applied to `double` values. These +// rules may also be applied to the `google.protobuf.DoubleValue` Well-Known-Type. +type DoubleRules struct { + state protoimpl.MessageState `protogen:"open.v1"` + // `const` requires the field value to exactly match the specified value. If + // the field value doesn't match, an error message is generated. + // + // ```proto + // + // message MyDouble { + // // value must equal 42.0 + // double value = 1 [(buf.validate.field).double.const = 42.0]; + // } + // + // ``` + Const *float64 `protobuf:"fixed64,1,opt,name=const" json:"const,omitempty"` + // Types that are valid to be assigned to LessThan: + // + // *DoubleRules_Lt + // *DoubleRules_Lte + LessThan isDoubleRules_LessThan `protobuf_oneof:"less_than"` + // Types that are valid to be assigned to GreaterThan: + // + // *DoubleRules_Gt + // *DoubleRules_Gte + GreaterThan isDoubleRules_GreaterThan `protobuf_oneof:"greater_than"` + // `in` requires the field value to be equal to one of the specified values. 
+ // If the field value isn't one of the specified values, an error message is + // generated. + // + // ```proto + // + // message MyDouble { + // // must be in list [1.0, 2.0, 3.0] + // double value = 1 [(buf.validate.field).double = { in: [1.0, 2.0, 3.0] }]; + // } + // + // ``` + In []float64 `protobuf:"fixed64,6,rep,name=in" json:"in,omitempty"` + // `not_in` requires the field value to not be equal to any of the specified + // values. If the field value is one of the specified values, an error + // message is generated. + // + // ```proto + // + // message MyDouble { + // // value must not be in list [1.0, 2.0, 3.0] + // double value = 1 [(buf.validate.field).double = { not_in: [1.0, 2.0, 3.0] }]; + // } + // + // ``` + NotIn []float64 `protobuf:"fixed64,7,rep,name=not_in,json=notIn" json:"not_in,omitempty"` + // `finite` requires the field value to be finite. If the field value is + // infinite or NaN, an error message is generated. + Finite *bool `protobuf:"varint,8,opt,name=finite" json:"finite,omitempty"` + // `example` specifies values that the field may have. These values SHOULD + // conform to other rules. `example` values will not impact validation + // but may be used as helpful guidance on how to populate the given field. 
+ // + // ```proto + // + // message MyDouble { + // double value = 1 [ + // (buf.validate.field).double.example = 1.0, + // (buf.validate.field).double.example = inf + // ]; + // } + // + // ``` + Example []float64 `protobuf:"fixed64,9,rep,name=example" json:"example,omitempty"` + extensionFields protoimpl.ExtensionFields + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *DoubleRules) Reset() { + *x = DoubleRules{} + mi := &file_buf_validate_validate_proto_msgTypes[7] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *DoubleRules) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DoubleRules) ProtoMessage() {} + +func (x *DoubleRules) ProtoReflect() protoreflect.Message { + mi := &file_buf_validate_validate_proto_msgTypes[7] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DoubleRules.ProtoReflect.Descriptor instead. 
+func (*DoubleRules) Descriptor() ([]byte, []int) { + return file_buf_validate_validate_proto_rawDescGZIP(), []int{7} +} + +func (x *DoubleRules) GetConst() float64 { + if x != nil && x.Const != nil { + return *x.Const + } + return 0 +} + +func (x *DoubleRules) GetLessThan() isDoubleRules_LessThan { + if x != nil { + return x.LessThan + } + return nil +} + +func (x *DoubleRules) GetLt() float64 { + if x != nil { + if x, ok := x.LessThan.(*DoubleRules_Lt); ok { + return x.Lt + } + } + return 0 +} + +func (x *DoubleRules) GetLte() float64 { + if x != nil { + if x, ok := x.LessThan.(*DoubleRules_Lte); ok { + return x.Lte + } + } + return 0 +} + +func (x *DoubleRules) GetGreaterThan() isDoubleRules_GreaterThan { + if x != nil { + return x.GreaterThan + } + return nil +} + +func (x *DoubleRules) GetGt() float64 { + if x != nil { + if x, ok := x.GreaterThan.(*DoubleRules_Gt); ok { + return x.Gt + } + } + return 0 +} + +func (x *DoubleRules) GetGte() float64 { + if x != nil { + if x, ok := x.GreaterThan.(*DoubleRules_Gte); ok { + return x.Gte + } + } + return 0 +} + +func (x *DoubleRules) GetIn() []float64 { + if x != nil { + return x.In + } + return nil +} + +func (x *DoubleRules) GetNotIn() []float64 { + if x != nil { + return x.NotIn + } + return nil +} + +func (x *DoubleRules) GetFinite() bool { + if x != nil && x.Finite != nil { + return *x.Finite + } + return false +} + +func (x *DoubleRules) GetExample() []float64 { + if x != nil { + return x.Example + } + return nil +} + +type isDoubleRules_LessThan interface { + isDoubleRules_LessThan() +} + +type DoubleRules_Lt struct { + // `lt` requires the field value to be less than the specified value (field < + // value). If the field value is equal to or greater than the specified + // value, an error message is generated. 
+ // + // ```proto + // + // message MyDouble { + // // must be less than 10.0 + // double value = 1 [(buf.validate.field).double.lt = 10.0]; + // } + // + // ``` + Lt float64 `protobuf:"fixed64,2,opt,name=lt,oneof"` +} + +type DoubleRules_Lte struct { + // `lte` requires the field value to be less than or equal to the specified value + // (field <= value). If the field value is greater than the specified value, + // an error message is generated. + // + // ```proto + // + // message MyDouble { + // // must be less than or equal to 10.0 + // double value = 1 [(buf.validate.field).double.lte = 10.0]; + // } + // + // ``` + Lte float64 `protobuf:"fixed64,3,opt,name=lte,oneof"` +} + +func (*DoubleRules_Lt) isDoubleRules_LessThan() {} + +func (*DoubleRules_Lte) isDoubleRules_LessThan() {} + +type isDoubleRules_GreaterThan interface { + isDoubleRules_GreaterThan() +} + +type DoubleRules_Gt struct { + // `gt` requires the field value to be greater than the specified value + // (exclusive). If the value of `gt` is larger than a specified `lt` or `lte`, + // the range is reversed, and the field value must be outside the specified + // range. If the field value doesn't meet the required conditions, an error + // message is generated. + // + // ```proto + // + // message MyDouble { + // // must be greater than 5.0 [double.gt] + // double value = 1 [(buf.validate.field).double.gt = 5.0]; + // + // // must be greater than 5 and less than 10.0 [double.gt_lt] + // double other_value = 2 [(buf.validate.field).double = { gt: 5.0, lt: 10.0 }]; + // + // // must be greater than 10 or less than 5.0 [double.gt_lt_exclusive] + // double another_value = 3 [(buf.validate.field).double = { gt: 10.0, lt: 5.0 }]; + // } + // + // ``` + Gt float64 `protobuf:"fixed64,4,opt,name=gt,oneof"` +} + +type DoubleRules_Gte struct { + // `gte` requires the field value to be greater than or equal to the specified + // value (exclusive). 
If the value of `gte` is larger than a specified `lt` or + // `lte`, the range is reversed, and the field value must be outside the + // specified range. If the field value doesn't meet the required conditions, + // an error message is generated. + // + // ```proto + // + // message MyDouble { + // // must be greater than or equal to 5.0 [double.gte] + // double value = 1 [(buf.validate.field).double.gte = 5.0]; + // + // // must be greater than or equal to 5.0 and less than 10.0 [double.gte_lt] + // double other_value = 2 [(buf.validate.field).double = { gte: 5.0, lt: 10.0 }]; + // + // // must be greater than or equal to 10.0 or less than 5.0 [double.gte_lt_exclusive] + // double another_value = 3 [(buf.validate.field).double = { gte: 10.0, lt: 5.0 }]; + // } + // + // ``` + Gte float64 `protobuf:"fixed64,5,opt,name=gte,oneof"` +} + +func (*DoubleRules_Gt) isDoubleRules_GreaterThan() {} + +func (*DoubleRules_Gte) isDoubleRules_GreaterThan() {} + +// Int32Rules describes the rules applied to `int32` values. These +// rules may also be applied to the `google.protobuf.Int32Value` Well-Known-Type. +type Int32Rules struct { + state protoimpl.MessageState `protogen:"open.v1"` + // `const` requires the field value to exactly match the specified value. If + // the field value doesn't match, an error message is generated. + // + // ```proto + // + // message MyInt32 { + // // value must equal 42 + // int32 value = 1 [(buf.validate.field).int32.const = 42]; + // } + // + // ``` + Const *int32 `protobuf:"varint,1,opt,name=const" json:"const,omitempty"` + // Types that are valid to be assigned to LessThan: + // + // *Int32Rules_Lt + // *Int32Rules_Lte + LessThan isInt32Rules_LessThan `protobuf_oneof:"less_than"` + // Types that are valid to be assigned to GreaterThan: + // + // *Int32Rules_Gt + // *Int32Rules_Gte + GreaterThan isInt32Rules_GreaterThan `protobuf_oneof:"greater_than"` + // `in` requires the field value to be equal to one of the specified values. 
+ // If the field value isn't one of the specified values, an error message is + // generated. + // + // ```proto + // + // message MyInt32 { + // // must be in list [1, 2, 3] + // int32 value = 1 [(buf.validate.field).int32 = { in: [1, 2, 3] }]; + // } + // + // ``` + In []int32 `protobuf:"varint,6,rep,name=in" json:"in,omitempty"` + // `not_in` requires the field value to not be equal to any of the specified + // values. If the field value is one of the specified values, an error message + // is generated. + // + // ```proto + // + // message MyInt32 { + // // value must not be in list [1, 2, 3] + // int32 value = 1 [(buf.validate.field).int32 = { not_in: [1, 2, 3] }]; + // } + // + // ``` + NotIn []int32 `protobuf:"varint,7,rep,name=not_in,json=notIn" json:"not_in,omitempty"` + // `example` specifies values that the field may have. These values SHOULD + // conform to other rules. `example` values will not impact validation + // but may be used as helpful guidance on how to populate the given field. 
+ // + // ```proto + // + // message MyInt32 { + // int32 value = 1 [ + // (buf.validate.field).int32.example = 1, + // (buf.validate.field).int32.example = -10 + // ]; + // } + // + // ``` + Example []int32 `protobuf:"varint,8,rep,name=example" json:"example,omitempty"` + extensionFields protoimpl.ExtensionFields + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Int32Rules) Reset() { + *x = Int32Rules{} + mi := &file_buf_validate_validate_proto_msgTypes[8] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Int32Rules) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Int32Rules) ProtoMessage() {} + +func (x *Int32Rules) ProtoReflect() protoreflect.Message { + mi := &file_buf_validate_validate_proto_msgTypes[8] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Int32Rules.ProtoReflect.Descriptor instead. 
+func (*Int32Rules) Descriptor() ([]byte, []int) { + return file_buf_validate_validate_proto_rawDescGZIP(), []int{8} +} + +func (x *Int32Rules) GetConst() int32 { + if x != nil && x.Const != nil { + return *x.Const + } + return 0 +} + +func (x *Int32Rules) GetLessThan() isInt32Rules_LessThan { + if x != nil { + return x.LessThan + } + return nil +} + +func (x *Int32Rules) GetLt() int32 { + if x != nil { + if x, ok := x.LessThan.(*Int32Rules_Lt); ok { + return x.Lt + } + } + return 0 +} + +func (x *Int32Rules) GetLte() int32 { + if x != nil { + if x, ok := x.LessThan.(*Int32Rules_Lte); ok { + return x.Lte + } + } + return 0 +} + +func (x *Int32Rules) GetGreaterThan() isInt32Rules_GreaterThan { + if x != nil { + return x.GreaterThan + } + return nil +} + +func (x *Int32Rules) GetGt() int32 { + if x != nil { + if x, ok := x.GreaterThan.(*Int32Rules_Gt); ok { + return x.Gt + } + } + return 0 +} + +func (x *Int32Rules) GetGte() int32 { + if x != nil { + if x, ok := x.GreaterThan.(*Int32Rules_Gte); ok { + return x.Gte + } + } + return 0 +} + +func (x *Int32Rules) GetIn() []int32 { + if x != nil { + return x.In + } + return nil +} + +func (x *Int32Rules) GetNotIn() []int32 { + if x != nil { + return x.NotIn + } + return nil +} + +func (x *Int32Rules) GetExample() []int32 { + if x != nil { + return x.Example + } + return nil +} + +type isInt32Rules_LessThan interface { + isInt32Rules_LessThan() +} + +type Int32Rules_Lt struct { + // `lt` requires the field value to be less than the specified value (field + // < value). If the field value is equal to or greater than the specified + // value, an error message is generated. + // + // ```proto + // + // message MyInt32 { + // // must be less than 10 + // int32 value = 1 [(buf.validate.field).int32.lt = 10]; + // } + // + // ``` + Lt int32 `protobuf:"varint,2,opt,name=lt,oneof"` +} + +type Int32Rules_Lte struct { + // `lte` requires the field value to be less than or equal to the specified + // value (field <= value). 
If the field value is greater than the specified + // value, an error message is generated. + // + // ```proto + // + // message MyInt32 { + // // must be less than or equal to 10 + // int32 value = 1 [(buf.validate.field).int32.lte = 10]; + // } + // + // ``` + Lte int32 `protobuf:"varint,3,opt,name=lte,oneof"` +} + +func (*Int32Rules_Lt) isInt32Rules_LessThan() {} + +func (*Int32Rules_Lte) isInt32Rules_LessThan() {} + +type isInt32Rules_GreaterThan interface { + isInt32Rules_GreaterThan() +} + +type Int32Rules_Gt struct { + // `gt` requires the field value to be greater than the specified value + // (exclusive). If the value of `gt` is larger than a specified `lt` or + // `lte`, the range is reversed, and the field value must be outside the + // specified range. If the field value doesn't meet the required conditions, + // an error message is generated. + // + // ```proto + // + // message MyInt32 { + // // must be greater than 5 [int32.gt] + // int32 value = 1 [(buf.validate.field).int32.gt = 5]; + // + // // must be greater than 5 and less than 10 [int32.gt_lt] + // int32 other_value = 2 [(buf.validate.field).int32 = { gt: 5, lt: 10 }]; + // + // // must be greater than 10 or less than 5 [int32.gt_lt_exclusive] + // int32 another_value = 3 [(buf.validate.field).int32 = { gt: 10, lt: 5 }]; + // } + // + // ``` + Gt int32 `protobuf:"varint,4,opt,name=gt,oneof"` +} + +type Int32Rules_Gte struct { + // `gte` requires the field value to be greater than or equal to the specified value + // (exclusive). If the value of `gte` is larger than a specified `lt` or + // `lte`, the range is reversed, and the field value must be outside the + // specified range. If the field value doesn't meet the required conditions, + // an error message is generated. 
+ // + // ```proto + // + // message MyInt32 { + // // must be greater than or equal to 5 [int32.gte] + // int32 value = 1 [(buf.validate.field).int32.gte = 5]; + // + // // must be greater than or equal to 5 and less than 10 [int32.gte_lt] + // int32 other_value = 2 [(buf.validate.field).int32 = { gte: 5, lt: 10 }]; + // + // // must be greater than or equal to 10 or less than 5 [int32.gte_lt_exclusive] + // int32 another_value = 3 [(buf.validate.field).int32 = { gte: 10, lt: 5 }]; + // } + // + // ``` + Gte int32 `protobuf:"varint,5,opt,name=gte,oneof"` +} + +func (*Int32Rules_Gt) isInt32Rules_GreaterThan() {} + +func (*Int32Rules_Gte) isInt32Rules_GreaterThan() {} + +// Int64Rules describes the rules applied to `int64` values. These +// rules may also be applied to the `google.protobuf.Int64Value` Well-Known-Type. +type Int64Rules struct { + state protoimpl.MessageState `protogen:"open.v1"` + // `const` requires the field value to exactly match the specified value. If + // the field value doesn't match, an error message is generated. + // + // ```proto + // + // message MyInt64 { + // // value must equal 42 + // int64 value = 1 [(buf.validate.field).int64.const = 42]; + // } + // + // ``` + Const *int64 `protobuf:"varint,1,opt,name=const" json:"const,omitempty"` + // Types that are valid to be assigned to LessThan: + // + // *Int64Rules_Lt + // *Int64Rules_Lte + LessThan isInt64Rules_LessThan `protobuf_oneof:"less_than"` + // Types that are valid to be assigned to GreaterThan: + // + // *Int64Rules_Gt + // *Int64Rules_Gte + GreaterThan isInt64Rules_GreaterThan `protobuf_oneof:"greater_than"` + // `in` requires the field value to be equal to one of the specified values. + // If the field value isn't one of the specified values, an error message is + // generated. 
+ // + // ```proto + // + // message MyInt64 { + // // must be in list [1, 2, 3] + // int64 value = 1 [(buf.validate.field).int64 = { in: [1, 2, 3] }]; + // } + // + // ``` + In []int64 `protobuf:"varint,6,rep,name=in" json:"in,omitempty"` + // `not_in` requires the field value to not be equal to any of the specified + // values. If the field value is one of the specified values, an error + // message is generated. + // + // ```proto + // + // message MyInt64 { + // // value must not be in list [1, 2, 3] + // int64 value = 1 [(buf.validate.field).int64 = { not_in: [1, 2, 3] }]; + // } + // + // ``` + NotIn []int64 `protobuf:"varint,7,rep,name=not_in,json=notIn" json:"not_in,omitempty"` + // `example` specifies values that the field may have. These values SHOULD + // conform to other rules. `example` values will not impact validation + // but may be used as helpful guidance on how to populate the given field. + // + // ```proto + // + // message MyInt64 { + // int64 value = 1 [ + // (buf.validate.field).int64.example = 1, + // (buf.validate.field).int64.example = -10 + // ]; + // } + // + // ``` + Example []int64 `protobuf:"varint,9,rep,name=example" json:"example,omitempty"` + extensionFields protoimpl.ExtensionFields + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Int64Rules) Reset() { + *x = Int64Rules{} + mi := &file_buf_validate_validate_proto_msgTypes[9] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Int64Rules) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Int64Rules) ProtoMessage() {} + +func (x *Int64Rules) ProtoReflect() protoreflect.Message { + mi := &file_buf_validate_validate_proto_msgTypes[9] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Int64Rules.ProtoReflect.Descriptor instead. 
+func (*Int64Rules) Descriptor() ([]byte, []int) { + return file_buf_validate_validate_proto_rawDescGZIP(), []int{9} +} + +func (x *Int64Rules) GetConst() int64 { + if x != nil && x.Const != nil { + return *x.Const + } + return 0 +} + +func (x *Int64Rules) GetLessThan() isInt64Rules_LessThan { + if x != nil { + return x.LessThan + } + return nil +} + +func (x *Int64Rules) GetLt() int64 { + if x != nil { + if x, ok := x.LessThan.(*Int64Rules_Lt); ok { + return x.Lt + } + } + return 0 +} + +func (x *Int64Rules) GetLte() int64 { + if x != nil { + if x, ok := x.LessThan.(*Int64Rules_Lte); ok { + return x.Lte + } + } + return 0 +} + +func (x *Int64Rules) GetGreaterThan() isInt64Rules_GreaterThan { + if x != nil { + return x.GreaterThan + } + return nil +} + +func (x *Int64Rules) GetGt() int64 { + if x != nil { + if x, ok := x.GreaterThan.(*Int64Rules_Gt); ok { + return x.Gt + } + } + return 0 +} + +func (x *Int64Rules) GetGte() int64 { + if x != nil { + if x, ok := x.GreaterThan.(*Int64Rules_Gte); ok { + return x.Gte + } + } + return 0 +} + +func (x *Int64Rules) GetIn() []int64 { + if x != nil { + return x.In + } + return nil +} + +func (x *Int64Rules) GetNotIn() []int64 { + if x != nil { + return x.NotIn + } + return nil +} + +func (x *Int64Rules) GetExample() []int64 { + if x != nil { + return x.Example + } + return nil +} + +type isInt64Rules_LessThan interface { + isInt64Rules_LessThan() +} + +type Int64Rules_Lt struct { + // `lt` requires the field value to be less than the specified value (field < + // value). If the field value is equal to or greater than the specified value, + // an error message is generated. + // + // ```proto + // + // message MyInt64 { + // // must be less than 10 + // int64 value = 1 [(buf.validate.field).int64.lt = 10]; + // } + // + // ``` + Lt int64 `protobuf:"varint,2,opt,name=lt,oneof"` +} + +type Int64Rules_Lte struct { + // `lte` requires the field value to be less than or equal to the specified + // value (field <= value). 
If the field value is greater than the specified + // value, an error message is generated. + // + // ```proto + // + // message MyInt64 { + // // must be less than or equal to 10 + // int64 value = 1 [(buf.validate.field).int64.lte = 10]; + // } + // + // ``` + Lte int64 `protobuf:"varint,3,opt,name=lte,oneof"` +} + +func (*Int64Rules_Lt) isInt64Rules_LessThan() {} + +func (*Int64Rules_Lte) isInt64Rules_LessThan() {} + +type isInt64Rules_GreaterThan interface { + isInt64Rules_GreaterThan() +} + +type Int64Rules_Gt struct { + // `gt` requires the field value to be greater than the specified value + // (exclusive). If the value of `gt` is larger than a specified `lt` or + // `lte`, the range is reversed, and the field value must be outside the + // specified range. If the field value doesn't meet the required conditions, + // an error message is generated. + // + // ```proto + // + // message MyInt64 { + // // must be greater than 5 [int64.gt] + // int64 value = 1 [(buf.validate.field).int64.gt = 5]; + // + // // must be greater than 5 and less than 10 [int64.gt_lt] + // int64 other_value = 2 [(buf.validate.field).int64 = { gt: 5, lt: 10 }]; + // + // // must be greater than 10 or less than 5 [int64.gt_lt_exclusive] + // int64 another_value = 3 [(buf.validate.field).int64 = { gt: 10, lt: 5 }]; + // } + // + // ``` + Gt int64 `protobuf:"varint,4,opt,name=gt,oneof"` +} + +type Int64Rules_Gte struct { + // `gte` requires the field value to be greater than or equal to the specified + // value (exclusive). If the value of `gte` is larger than a specified `lt` + // or `lte`, the range is reversed, and the field value must be outside the + // specified range. If the field value doesn't meet the required conditions, + // an error message is generated. 
+ // + // ```proto + // + // message MyInt64 { + // // must be greater than or equal to 5 [int64.gte] + // int64 value = 1 [(buf.validate.field).int64.gte = 5]; + // + // // must be greater than or equal to 5 and less than 10 [int64.gte_lt] + // int64 other_value = 2 [(buf.validate.field).int64 = { gte: 5, lt: 10 }]; + // + // // must be greater than or equal to 10 or less than 5 [int64.gte_lt_exclusive] + // int64 another_value = 3 [(buf.validate.field).int64 = { gte: 10, lt: 5 }]; + // } + // + // ``` + Gte int64 `protobuf:"varint,5,opt,name=gte,oneof"` +} + +func (*Int64Rules_Gt) isInt64Rules_GreaterThan() {} + +func (*Int64Rules_Gte) isInt64Rules_GreaterThan() {} + +// UInt32Rules describes the rules applied to `uint32` values. These +// rules may also be applied to the `google.protobuf.UInt32Value` Well-Known-Type. +type UInt32Rules struct { + state protoimpl.MessageState `protogen:"open.v1"` + // `const` requires the field value to exactly match the specified value. If + // the field value doesn't match, an error message is generated. + // + // ```proto + // + // message MyUInt32 { + // // value must equal 42 + // uint32 value = 1 [(buf.validate.field).uint32.const = 42]; + // } + // + // ``` + Const *uint32 `protobuf:"varint,1,opt,name=const" json:"const,omitempty"` + // Types that are valid to be assigned to LessThan: + // + // *UInt32Rules_Lt + // *UInt32Rules_Lte + LessThan isUInt32Rules_LessThan `protobuf_oneof:"less_than"` + // Types that are valid to be assigned to GreaterThan: + // + // *UInt32Rules_Gt + // *UInt32Rules_Gte + GreaterThan isUInt32Rules_GreaterThan `protobuf_oneof:"greater_than"` + // `in` requires the field value to be equal to one of the specified values. + // If the field value isn't one of the specified values, an error message is + // generated. 
+ // + // ```proto + // + // message MyUInt32 { + // // must be in list [1, 2, 3] + // uint32 value = 1 [(buf.validate.field).uint32 = { in: [1, 2, 3] }]; + // } + // + // ``` + In []uint32 `protobuf:"varint,6,rep,name=in" json:"in,omitempty"` + // `not_in` requires the field value to not be equal to any of the specified + // values. If the field value is one of the specified values, an error + // message is generated. + // + // ```proto + // + // message MyUInt32 { + // // value must not be in list [1, 2, 3] + // uint32 value = 1 [(buf.validate.field).uint32 = { not_in: [1, 2, 3] }]; + // } + // + // ``` + NotIn []uint32 `protobuf:"varint,7,rep,name=not_in,json=notIn" json:"not_in,omitempty"` + // `example` specifies values that the field may have. These values SHOULD + // conform to other rules. `example` values will not impact validation + // but may be used as helpful guidance on how to populate the given field. + // + // ```proto + // + // message MyUInt32 { + // uint32 value = 1 [ + // (buf.validate.field).uint32.example = 1, + // (buf.validate.field).uint32.example = 10 + // ]; + // } + // + // ``` + Example []uint32 `protobuf:"varint,8,rep,name=example" json:"example,omitempty"` + extensionFields protoimpl.ExtensionFields + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *UInt32Rules) Reset() { + *x = UInt32Rules{} + mi := &file_buf_validate_validate_proto_msgTypes[10] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *UInt32Rules) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*UInt32Rules) ProtoMessage() {} + +func (x *UInt32Rules) ProtoReflect() protoreflect.Message { + mi := &file_buf_validate_validate_proto_msgTypes[10] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use 
UInt32Rules.ProtoReflect.Descriptor instead. +func (*UInt32Rules) Descriptor() ([]byte, []int) { + return file_buf_validate_validate_proto_rawDescGZIP(), []int{10} +} + +func (x *UInt32Rules) GetConst() uint32 { + if x != nil && x.Const != nil { + return *x.Const + } + return 0 +} + +func (x *UInt32Rules) GetLessThan() isUInt32Rules_LessThan { + if x != nil { + return x.LessThan + } + return nil +} + +func (x *UInt32Rules) GetLt() uint32 { + if x != nil { + if x, ok := x.LessThan.(*UInt32Rules_Lt); ok { + return x.Lt + } + } + return 0 +} + +func (x *UInt32Rules) GetLte() uint32 { + if x != nil { + if x, ok := x.LessThan.(*UInt32Rules_Lte); ok { + return x.Lte + } + } + return 0 +} + +func (x *UInt32Rules) GetGreaterThan() isUInt32Rules_GreaterThan { + if x != nil { + return x.GreaterThan + } + return nil +} + +func (x *UInt32Rules) GetGt() uint32 { + if x != nil { + if x, ok := x.GreaterThan.(*UInt32Rules_Gt); ok { + return x.Gt + } + } + return 0 +} + +func (x *UInt32Rules) GetGte() uint32 { + if x != nil { + if x, ok := x.GreaterThan.(*UInt32Rules_Gte); ok { + return x.Gte + } + } + return 0 +} + +func (x *UInt32Rules) GetIn() []uint32 { + if x != nil { + return x.In + } + return nil +} + +func (x *UInt32Rules) GetNotIn() []uint32 { + if x != nil { + return x.NotIn + } + return nil +} + +func (x *UInt32Rules) GetExample() []uint32 { + if x != nil { + return x.Example + } + return nil +} + +type isUInt32Rules_LessThan interface { + isUInt32Rules_LessThan() +} + +type UInt32Rules_Lt struct { + // `lt` requires the field value to be less than the specified value (field < + // value). If the field value is equal to or greater than the specified value, + // an error message is generated. 
+ // + // ```proto + // + // message MyUInt32 { + // // must be less than 10 + // uint32 value = 1 [(buf.validate.field).uint32.lt = 10]; + // } + // + // ``` + Lt uint32 `protobuf:"varint,2,opt,name=lt,oneof"` +} + +type UInt32Rules_Lte struct { + // `lte` requires the field value to be less than or equal to the specified + // value (field <= value). If the field value is greater than the specified + // value, an error message is generated. + // + // ```proto + // + // message MyUInt32 { + // // must be less than or equal to 10 + // uint32 value = 1 [(buf.validate.field).uint32.lte = 10]; + // } + // + // ``` + Lte uint32 `protobuf:"varint,3,opt,name=lte,oneof"` +} + +func (*UInt32Rules_Lt) isUInt32Rules_LessThan() {} + +func (*UInt32Rules_Lte) isUInt32Rules_LessThan() {} + +type isUInt32Rules_GreaterThan interface { + isUInt32Rules_GreaterThan() +} + +type UInt32Rules_Gt struct { + // `gt` requires the field value to be greater than the specified value + // (exclusive). If the value of `gt` is larger than a specified `lt` or + // `lte`, the range is reversed, and the field value must be outside the + // specified range. If the field value doesn't meet the required conditions, + // an error message is generated. + // + // ```proto + // + // message MyUInt32 { + // // must be greater than 5 [uint32.gt] + // uint32 value = 1 [(buf.validate.field).uint32.gt = 5]; + // + // // must be greater than 5 and less than 10 [uint32.gt_lt] + // uint32 other_value = 2 [(buf.validate.field).uint32 = { gt: 5, lt: 10 }]; + // + // // must be greater than 10 or less than 5 [uint32.gt_lt_exclusive] + // uint32 another_value = 3 [(buf.validate.field).uint32 = { gt: 10, lt: 5 }]; + // } + // + // ``` + Gt uint32 `protobuf:"varint,4,opt,name=gt,oneof"` +} + +type UInt32Rules_Gte struct { + // `gte` requires the field value to be greater than or equal to the specified + // value (exclusive). 
If the value of `gte` is larger than a specified `lt` + // or `lte`, the range is reversed, and the field value must be outside the + // specified range. If the field value doesn't meet the required conditions, + // an error message is generated. + // + // ```proto + // + // message MyUInt32 { + // // must be greater than or equal to 5 [uint32.gte] + // uint32 value = 1 [(buf.validate.field).uint32.gte = 5]; + // + // // must be greater than or equal to 5 and less than 10 [uint32.gte_lt] + // uint32 other_value = 2 [(buf.validate.field).uint32 = { gte: 5, lt: 10 }]; + // + // // must be greater than or equal to 10 or less than 5 [uint32.gte_lt_exclusive] + // uint32 another_value = 3 [(buf.validate.field).uint32 = { gte: 10, lt: 5 }]; + // } + // + // ``` + Gte uint32 `protobuf:"varint,5,opt,name=gte,oneof"` +} + +func (*UInt32Rules_Gt) isUInt32Rules_GreaterThan() {} + +func (*UInt32Rules_Gte) isUInt32Rules_GreaterThan() {} + +// UInt64Rules describes the rules applied to `uint64` values. These +// rules may also be applied to the `google.protobuf.UInt64Value` Well-Known-Type. +type UInt64Rules struct { + state protoimpl.MessageState `protogen:"open.v1"` + // `const` requires the field value to exactly match the specified value. If + // the field value doesn't match, an error message is generated. + // + // ```proto + // + // message MyUInt64 { + // // value must equal 42 + // uint64 value = 1 [(buf.validate.field).uint64.const = 42]; + // } + // + // ``` + Const *uint64 `protobuf:"varint,1,opt,name=const" json:"const,omitempty"` + // Types that are valid to be assigned to LessThan: + // + // *UInt64Rules_Lt + // *UInt64Rules_Lte + LessThan isUInt64Rules_LessThan `protobuf_oneof:"less_than"` + // Types that are valid to be assigned to GreaterThan: + // + // *UInt64Rules_Gt + // *UInt64Rules_Gte + GreaterThan isUInt64Rules_GreaterThan `protobuf_oneof:"greater_than"` + // `in` requires the field value to be equal to one of the specified values. 
+ // If the field value isn't one of the specified values, an error message is + // generated. + // + // ```proto + // + // message MyUInt64 { + // // must be in list [1, 2, 3] + // uint64 value = 1 [(buf.validate.field).uint64 = { in: [1, 2, 3] }]; + // } + // + // ``` + In []uint64 `protobuf:"varint,6,rep,name=in" json:"in,omitempty"` + // `not_in` requires the field value to not be equal to any of the specified + // values. If the field value is one of the specified values, an error + // message is generated. + // + // ```proto + // + // message MyUInt64 { + // // value must not be in list [1, 2, 3] + // uint64 value = 1 [(buf.validate.field).uint64 = { not_in: [1, 2, 3] }]; + // } + // + // ``` + NotIn []uint64 `protobuf:"varint,7,rep,name=not_in,json=notIn" json:"not_in,omitempty"` + // `example` specifies values that the field may have. These values SHOULD + // conform to other rules. `example` values will not impact validation + // but may be used as helpful guidance on how to populate the given field. 
+ // + // ```proto + // + // message MyUInt64 { + // uint64 value = 1 [ + // (buf.validate.field).uint64.example = 1, + // (buf.validate.field).uint64.example = -10 + // ]; + // } + // + // ``` + Example []uint64 `protobuf:"varint,8,rep,name=example" json:"example,omitempty"` + extensionFields protoimpl.ExtensionFields + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *UInt64Rules) Reset() { + *x = UInt64Rules{} + mi := &file_buf_validate_validate_proto_msgTypes[11] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *UInt64Rules) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*UInt64Rules) ProtoMessage() {} + +func (x *UInt64Rules) ProtoReflect() protoreflect.Message { + mi := &file_buf_validate_validate_proto_msgTypes[11] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use UInt64Rules.ProtoReflect.Descriptor instead. 
+func (*UInt64Rules) Descriptor() ([]byte, []int) { + return file_buf_validate_validate_proto_rawDescGZIP(), []int{11} +} + +func (x *UInt64Rules) GetConst() uint64 { + if x != nil && x.Const != nil { + return *x.Const + } + return 0 +} + +func (x *UInt64Rules) GetLessThan() isUInt64Rules_LessThan { + if x != nil { + return x.LessThan + } + return nil +} + +func (x *UInt64Rules) GetLt() uint64 { + if x != nil { + if x, ok := x.LessThan.(*UInt64Rules_Lt); ok { + return x.Lt + } + } + return 0 +} + +func (x *UInt64Rules) GetLte() uint64 { + if x != nil { + if x, ok := x.LessThan.(*UInt64Rules_Lte); ok { + return x.Lte + } + } + return 0 +} + +func (x *UInt64Rules) GetGreaterThan() isUInt64Rules_GreaterThan { + if x != nil { + return x.GreaterThan + } + return nil +} + +func (x *UInt64Rules) GetGt() uint64 { + if x != nil { + if x, ok := x.GreaterThan.(*UInt64Rules_Gt); ok { + return x.Gt + } + } + return 0 +} + +func (x *UInt64Rules) GetGte() uint64 { + if x != nil { + if x, ok := x.GreaterThan.(*UInt64Rules_Gte); ok { + return x.Gte + } + } + return 0 +} + +func (x *UInt64Rules) GetIn() []uint64 { + if x != nil { + return x.In + } + return nil +} + +func (x *UInt64Rules) GetNotIn() []uint64 { + if x != nil { + return x.NotIn + } + return nil +} + +func (x *UInt64Rules) GetExample() []uint64 { + if x != nil { + return x.Example + } + return nil +} + +type isUInt64Rules_LessThan interface { + isUInt64Rules_LessThan() +} + +type UInt64Rules_Lt struct { + // `lt` requires the field value to be less than the specified value (field < + // value). If the field value is equal to or greater than the specified value, + // an error message is generated. 
+ // + // ```proto + // + // message MyUInt64 { + // // must be less than 10 + // uint64 value = 1 [(buf.validate.field).uint64.lt = 10]; + // } + // + // ``` + Lt uint64 `protobuf:"varint,2,opt,name=lt,oneof"` +} + +type UInt64Rules_Lte struct { + // `lte` requires the field value to be less than or equal to the specified + // value (field <= value). If the field value is greater than the specified + // value, an error message is generated. + // + // ```proto + // + // message MyUInt64 { + // // must be less than or equal to 10 + // uint64 value = 1 [(buf.validate.field).uint64.lte = 10]; + // } + // + // ``` + Lte uint64 `protobuf:"varint,3,opt,name=lte,oneof"` +} + +func (*UInt64Rules_Lt) isUInt64Rules_LessThan() {} + +func (*UInt64Rules_Lte) isUInt64Rules_LessThan() {} + +type isUInt64Rules_GreaterThan interface { + isUInt64Rules_GreaterThan() +} + +type UInt64Rules_Gt struct { + // `gt` requires the field value to be greater than the specified value + // (exclusive). If the value of `gt` is larger than a specified `lt` or + // `lte`, the range is reversed, and the field value must be outside the + // specified range. If the field value doesn't meet the required conditions, + // an error message is generated. + // + // ```proto + // + // message MyUInt64 { + // // must be greater than 5 [uint64.gt] + // uint64 value = 1 [(buf.validate.field).uint64.gt = 5]; + // + // // must be greater than 5 and less than 10 [uint64.gt_lt] + // uint64 other_value = 2 [(buf.validate.field).uint64 = { gt: 5, lt: 10 }]; + // + // // must be greater than 10 or less than 5 [uint64.gt_lt_exclusive] + // uint64 another_value = 3 [(buf.validate.field).uint64 = { gt: 10, lt: 5 }]; + // } + // + // ``` + Gt uint64 `protobuf:"varint,4,opt,name=gt,oneof"` +} + +type UInt64Rules_Gte struct { + // `gte` requires the field value to be greater than or equal to the specified + // value (exclusive). 
If the value of `gte` is larger than a specified `lt` + // or `lte`, the range is reversed, and the field value must be outside the + // specified range. If the field value doesn't meet the required conditions, + // an error message is generated. + // + // ```proto + // + // message MyUInt64 { + // // must be greater than or equal to 5 [uint64.gte] + // uint64 value = 1 [(buf.validate.field).uint64.gte = 5]; + // + // // must be greater than or equal to 5 and less than 10 [uint64.gte_lt] + // uint64 other_value = 2 [(buf.validate.field).uint64 = { gte: 5, lt: 10 }]; + // + // // must be greater than or equal to 10 or less than 5 [uint64.gte_lt_exclusive] + // uint64 another_value = 3 [(buf.validate.field).uint64 = { gte: 10, lt: 5 }]; + // } + // + // ``` + Gte uint64 `protobuf:"varint,5,opt,name=gte,oneof"` +} + +func (*UInt64Rules_Gt) isUInt64Rules_GreaterThan() {} + +func (*UInt64Rules_Gte) isUInt64Rules_GreaterThan() {} + +// SInt32Rules describes the rules applied to `sint32` values. +type SInt32Rules struct { + state protoimpl.MessageState `protogen:"open.v1"` + // `const` requires the field value to exactly match the specified value. If + // the field value doesn't match, an error message is generated. + // + // ```proto + // + // message MySInt32 { + // // value must equal 42 + // sint32 value = 1 [(buf.validate.field).sint32.const = 42]; + // } + // + // ``` + Const *int32 `protobuf:"zigzag32,1,opt,name=const" json:"const,omitempty"` + // Types that are valid to be assigned to LessThan: + // + // *SInt32Rules_Lt + // *SInt32Rules_Lte + LessThan isSInt32Rules_LessThan `protobuf_oneof:"less_than"` + // Types that are valid to be assigned to GreaterThan: + // + // *SInt32Rules_Gt + // *SInt32Rules_Gte + GreaterThan isSInt32Rules_GreaterThan `protobuf_oneof:"greater_than"` + // `in` requires the field value to be equal to one of the specified values. + // If the field value isn't one of the specified values, an error message is + // generated. 
+ // + // ```proto + // + // message MySInt32 { + // // must be in list [1, 2, 3] + // sint32 value = 1 [(buf.validate.field).sint32 = { in: [1, 2, 3] }]; + // } + // + // ``` + In []int32 `protobuf:"zigzag32,6,rep,name=in" json:"in,omitempty"` + // `not_in` requires the field value to not be equal to any of the specified + // values. If the field value is one of the specified values, an error + // message is generated. + // + // ```proto + // + // message MySInt32 { + // // value must not be in list [1, 2, 3] + // sint32 value = 1 [(buf.validate.field).sint32 = { not_in: [1, 2, 3] }]; + // } + // + // ``` + NotIn []int32 `protobuf:"zigzag32,7,rep,name=not_in,json=notIn" json:"not_in,omitempty"` + // `example` specifies values that the field may have. These values SHOULD + // conform to other rules. `example` values will not impact validation + // but may be used as helpful guidance on how to populate the given field. + // + // ```proto + // + // message MySInt32 { + // sint32 value = 1 [ + // (buf.validate.field).sint32.example = 1, + // (buf.validate.field).sint32.example = -10 + // ]; + // } + // + // ``` + Example []int32 `protobuf:"zigzag32,8,rep,name=example" json:"example,omitempty"` + extensionFields protoimpl.ExtensionFields + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *SInt32Rules) Reset() { + *x = SInt32Rules{} + mi := &file_buf_validate_validate_proto_msgTypes[12] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SInt32Rules) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SInt32Rules) ProtoMessage() {} + +func (x *SInt32Rules) ProtoReflect() protoreflect.Message { + mi := &file_buf_validate_validate_proto_msgTypes[12] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use 
SInt32Rules.ProtoReflect.Descriptor instead. +func (*SInt32Rules) Descriptor() ([]byte, []int) { + return file_buf_validate_validate_proto_rawDescGZIP(), []int{12} +} + +func (x *SInt32Rules) GetConst() int32 { + if x != nil && x.Const != nil { + return *x.Const + } + return 0 +} + +func (x *SInt32Rules) GetLessThan() isSInt32Rules_LessThan { + if x != nil { + return x.LessThan + } + return nil +} + +func (x *SInt32Rules) GetLt() int32 { + if x != nil { + if x, ok := x.LessThan.(*SInt32Rules_Lt); ok { + return x.Lt + } + } + return 0 +} + +func (x *SInt32Rules) GetLte() int32 { + if x != nil { + if x, ok := x.LessThan.(*SInt32Rules_Lte); ok { + return x.Lte + } + } + return 0 +} + +func (x *SInt32Rules) GetGreaterThan() isSInt32Rules_GreaterThan { + if x != nil { + return x.GreaterThan + } + return nil +} + +func (x *SInt32Rules) GetGt() int32 { + if x != nil { + if x, ok := x.GreaterThan.(*SInt32Rules_Gt); ok { + return x.Gt + } + } + return 0 +} + +func (x *SInt32Rules) GetGte() int32 { + if x != nil { + if x, ok := x.GreaterThan.(*SInt32Rules_Gte); ok { + return x.Gte + } + } + return 0 +} + +func (x *SInt32Rules) GetIn() []int32 { + if x != nil { + return x.In + } + return nil +} + +func (x *SInt32Rules) GetNotIn() []int32 { + if x != nil { + return x.NotIn + } + return nil +} + +func (x *SInt32Rules) GetExample() []int32 { + if x != nil { + return x.Example + } + return nil +} + +type isSInt32Rules_LessThan interface { + isSInt32Rules_LessThan() +} + +type SInt32Rules_Lt struct { + // `lt` requires the field value to be less than the specified value (field + // < value). If the field value is equal to or greater than the specified + // value, an error message is generated. 
+ // + // ```proto + // + // message MySInt32 { + // // must be less than 10 + // sint32 value = 1 [(buf.validate.field).sint32.lt = 10]; + // } + // + // ``` + Lt int32 `protobuf:"zigzag32,2,opt,name=lt,oneof"` +} + +type SInt32Rules_Lte struct { + // `lte` requires the field value to be less than or equal to the specified + // value (field <= value). If the field value is greater than the specified + // value, an error message is generated. + // + // ```proto + // + // message MySInt32 { + // // must be less than or equal to 10 + // sint32 value = 1 [(buf.validate.field).sint32.lte = 10]; + // } + // + // ``` + Lte int32 `protobuf:"zigzag32,3,opt,name=lte,oneof"` +} + +func (*SInt32Rules_Lt) isSInt32Rules_LessThan() {} + +func (*SInt32Rules_Lte) isSInt32Rules_LessThan() {} + +type isSInt32Rules_GreaterThan interface { + isSInt32Rules_GreaterThan() +} + +type SInt32Rules_Gt struct { + // `gt` requires the field value to be greater than the specified value + // (exclusive). If the value of `gt` is larger than a specified `lt` or + // `lte`, the range is reversed, and the field value must be outside the + // specified range. If the field value doesn't meet the required conditions, + // an error message is generated. + // + // ```proto + // + // message MySInt32 { + // // must be greater than 5 [sint32.gt] + // sint32 value = 1 [(buf.validate.field).sint32.gt = 5]; + // + // // must be greater than 5 and less than 10 [sint32.gt_lt] + // sint32 other_value = 2 [(buf.validate.field).sint32 = { gt: 5, lt: 10 }]; + // + // // must be greater than 10 or less than 5 [sint32.gt_lt_exclusive] + // sint32 another_value = 3 [(buf.validate.field).sint32 = { gt: 10, lt: 5 }]; + // } + // + // ``` + Gt int32 `protobuf:"zigzag32,4,opt,name=gt,oneof"` +} + +type SInt32Rules_Gte struct { + // `gte` requires the field value to be greater than or equal to the specified + // value (exclusive). 
If the value of `gte` is larger than a specified `lt` + // or `lte`, the range is reversed, and the field value must be outside the + // specified range. If the field value doesn't meet the required conditions, + // an error message is generated. + // + // ```proto + // + // message MySInt32 { + // // must be greater than or equal to 5 [sint32.gte] + // sint32 value = 1 [(buf.validate.field).sint32.gte = 5]; + // + // // must be greater than or equal to 5 and less than 10 [sint32.gte_lt] + // sint32 other_value = 2 [(buf.validate.field).sint32 = { gte: 5, lt: 10 }]; + // + // // must be greater than or equal to 10 or less than 5 [sint32.gte_lt_exclusive] + // sint32 another_value = 3 [(buf.validate.field).sint32 = { gte: 10, lt: 5 }]; + // } + // + // ``` + Gte int32 `protobuf:"zigzag32,5,opt,name=gte,oneof"` +} + +func (*SInt32Rules_Gt) isSInt32Rules_GreaterThan() {} + +func (*SInt32Rules_Gte) isSInt32Rules_GreaterThan() {} + +// SInt64Rules describes the rules applied to `sint64` values. +type SInt64Rules struct { + state protoimpl.MessageState `protogen:"open.v1"` + // `const` requires the field value to exactly match the specified value. If + // the field value doesn't match, an error message is generated. + // + // ```proto + // + // message MySInt64 { + // // value must equal 42 + // sint64 value = 1 [(buf.validate.field).sint64.const = 42]; + // } + // + // ``` + Const *int64 `protobuf:"zigzag64,1,opt,name=const" json:"const,omitempty"` + // Types that are valid to be assigned to LessThan: + // + // *SInt64Rules_Lt + // *SInt64Rules_Lte + LessThan isSInt64Rules_LessThan `protobuf_oneof:"less_than"` + // Types that are valid to be assigned to GreaterThan: + // + // *SInt64Rules_Gt + // *SInt64Rules_Gte + GreaterThan isSInt64Rules_GreaterThan `protobuf_oneof:"greater_than"` + // `in` requires the field value to be equal to one of the specified values. + // If the field value isn't one of the specified values, an error message + // is generated. 
+ // + // ```proto + // + // message MySInt64 { + // // must be in list [1, 2, 3] + // sint64 value = 1 [(buf.validate.field).sint64 = { in: [1, 2, 3] }]; + // } + // + // ``` + In []int64 `protobuf:"zigzag64,6,rep,name=in" json:"in,omitempty"` + // `not_in` requires the field value to not be equal to any of the specified + // values. If the field value is one of the specified values, an error + // message is generated. + // + // ```proto + // + // message MySInt64 { + // // value must not be in list [1, 2, 3] + // sint64 value = 1 [(buf.validate.field).sint64 = { not_in: [1, 2, 3] }]; + // } + // + // ``` + NotIn []int64 `protobuf:"zigzag64,7,rep,name=not_in,json=notIn" json:"not_in,omitempty"` + // `example` specifies values that the field may have. These values SHOULD + // conform to other rules. `example` values will not impact validation + // but may be used as helpful guidance on how to populate the given field. + // + // ```proto + // + // message MySInt64 { + // sint64 value = 1 [ + // (buf.validate.field).sint64.example = 1, + // (buf.validate.field).sint64.example = -10 + // ]; + // } + // + // ``` + Example []int64 `protobuf:"zigzag64,8,rep,name=example" json:"example,omitempty"` + extensionFields protoimpl.ExtensionFields + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *SInt64Rules) Reset() { + *x = SInt64Rules{} + mi := &file_buf_validate_validate_proto_msgTypes[13] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SInt64Rules) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SInt64Rules) ProtoMessage() {} + +func (x *SInt64Rules) ProtoReflect() protoreflect.Message { + mi := &file_buf_validate_validate_proto_msgTypes[13] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use 
SInt64Rules.ProtoReflect.Descriptor instead. +func (*SInt64Rules) Descriptor() ([]byte, []int) { + return file_buf_validate_validate_proto_rawDescGZIP(), []int{13} +} + +func (x *SInt64Rules) GetConst() int64 { + if x != nil && x.Const != nil { + return *x.Const + } + return 0 +} + +func (x *SInt64Rules) GetLessThan() isSInt64Rules_LessThan { + if x != nil { + return x.LessThan + } + return nil +} + +func (x *SInt64Rules) GetLt() int64 { + if x != nil { + if x, ok := x.LessThan.(*SInt64Rules_Lt); ok { + return x.Lt + } + } + return 0 +} + +func (x *SInt64Rules) GetLte() int64 { + if x != nil { + if x, ok := x.LessThan.(*SInt64Rules_Lte); ok { + return x.Lte + } + } + return 0 +} + +func (x *SInt64Rules) GetGreaterThan() isSInt64Rules_GreaterThan { + if x != nil { + return x.GreaterThan + } + return nil +} + +func (x *SInt64Rules) GetGt() int64 { + if x != nil { + if x, ok := x.GreaterThan.(*SInt64Rules_Gt); ok { + return x.Gt + } + } + return 0 +} + +func (x *SInt64Rules) GetGte() int64 { + if x != nil { + if x, ok := x.GreaterThan.(*SInt64Rules_Gte); ok { + return x.Gte + } + } + return 0 +} + +func (x *SInt64Rules) GetIn() []int64 { + if x != nil { + return x.In + } + return nil +} + +func (x *SInt64Rules) GetNotIn() []int64 { + if x != nil { + return x.NotIn + } + return nil +} + +func (x *SInt64Rules) GetExample() []int64 { + if x != nil { + return x.Example + } + return nil +} + +type isSInt64Rules_LessThan interface { + isSInt64Rules_LessThan() +} + +type SInt64Rules_Lt struct { + // `lt` requires the field value to be less than the specified value (field + // < value). If the field value is equal to or greater than the specified + // value, an error message is generated. 
+ // + // ```proto + // + // message MySInt64 { + // // must be less than 10 + // sint64 value = 1 [(buf.validate.field).sint64.lt = 10]; + // } + // + // ``` + Lt int64 `protobuf:"zigzag64,2,opt,name=lt,oneof"` +} + +type SInt64Rules_Lte struct { + // `lte` requires the field value to be less than or equal to the specified + // value (field <= value). If the field value is greater than the specified + // value, an error message is generated. + // + // ```proto + // + // message MySInt64 { + // // must be less than or equal to 10 + // sint64 value = 1 [(buf.validate.field).sint64.lte = 10]; + // } + // + // ``` + Lte int64 `protobuf:"zigzag64,3,opt,name=lte,oneof"` +} + +func (*SInt64Rules_Lt) isSInt64Rules_LessThan() {} + +func (*SInt64Rules_Lte) isSInt64Rules_LessThan() {} + +type isSInt64Rules_GreaterThan interface { + isSInt64Rules_GreaterThan() +} + +type SInt64Rules_Gt struct { + // `gt` requires the field value to be greater than the specified value + // (exclusive). If the value of `gt` is larger than a specified `lt` or + // `lte`, the range is reversed, and the field value must be outside the + // specified range. If the field value doesn't meet the required conditions, + // an error message is generated. + // + // ```proto + // + // message MySInt64 { + // // must be greater than 5 [sint64.gt] + // sint64 value = 1 [(buf.validate.field).sint64.gt = 5]; + // + // // must be greater than 5 and less than 10 [sint64.gt_lt] + // sint64 other_value = 2 [(buf.validate.field).sint64 = { gt: 5, lt: 10 }]; + // + // // must be greater than 10 or less than 5 [sint64.gt_lt_exclusive] + // sint64 another_value = 3 [(buf.validate.field).sint64 = { gt: 10, lt: 5 }]; + // } + // + // ``` + Gt int64 `protobuf:"zigzag64,4,opt,name=gt,oneof"` +} + +type SInt64Rules_Gte struct { + // `gte` requires the field value to be greater than or equal to the specified + // value (exclusive). 
If the value of `gte` is larger than a specified `lt` + // or `lte`, the range is reversed, and the field value must be outside the + // specified range. If the field value doesn't meet the required conditions, + // an error message is generated. + // + // ```proto + // + // message MySInt64 { + // // must be greater than or equal to 5 [sint64.gte] + // sint64 value = 1 [(buf.validate.field).sint64.gte = 5]; + // + // // must be greater than or equal to 5 and less than 10 [sint64.gte_lt] + // sint64 other_value = 2 [(buf.validate.field).sint64 = { gte: 5, lt: 10 }]; + // + // // must be greater than or equal to 10 or less than 5 [sint64.gte_lt_exclusive] + // sint64 another_value = 3 [(buf.validate.field).sint64 = { gte: 10, lt: 5 }]; + // } + // + // ``` + Gte int64 `protobuf:"zigzag64,5,opt,name=gte,oneof"` +} + +func (*SInt64Rules_Gt) isSInt64Rules_GreaterThan() {} + +func (*SInt64Rules_Gte) isSInt64Rules_GreaterThan() {} + +// Fixed32Rules describes the rules applied to `fixed32` values. +type Fixed32Rules struct { + state protoimpl.MessageState `protogen:"open.v1"` + // `const` requires the field value to exactly match the specified value. + // If the field value doesn't match, an error message is generated. + // + // ```proto + // + // message MyFixed32 { + // // value must equal 42 + // fixed32 value = 1 [(buf.validate.field).fixed32.const = 42]; + // } + // + // ``` + Const *uint32 `protobuf:"fixed32,1,opt,name=const" json:"const,omitempty"` + // Types that are valid to be assigned to LessThan: + // + // *Fixed32Rules_Lt + // *Fixed32Rules_Lte + LessThan isFixed32Rules_LessThan `protobuf_oneof:"less_than"` + // Types that are valid to be assigned to GreaterThan: + // + // *Fixed32Rules_Gt + // *Fixed32Rules_Gte + GreaterThan isFixed32Rules_GreaterThan `protobuf_oneof:"greater_than"` + // `in` requires the field value to be equal to one of the specified values. + // If the field value isn't one of the specified values, an error message + // is generated. 
+ // + // ```proto + // + // message MyFixed32 { + // // must be in list [1, 2, 3] + // fixed32 value = 1 [(buf.validate.field).fixed32 = { in: [1, 2, 3] }]; + // } + // + // ``` + In []uint32 `protobuf:"fixed32,6,rep,name=in" json:"in,omitempty"` + // `not_in` requires the field value to not be equal to any of the specified + // values. If the field value is one of the specified values, an error + // message is generated. + // + // ```proto + // + // message MyFixed32 { + // // value must not be in list [1, 2, 3] + // fixed32 value = 1 [(buf.validate.field).fixed32 = { not_in: [1, 2, 3] }]; + // } + // + // ``` + NotIn []uint32 `protobuf:"fixed32,7,rep,name=not_in,json=notIn" json:"not_in,omitempty"` + // `example` specifies values that the field may have. These values SHOULD + // conform to other rules. `example` values will not impact validation + // but may be used as helpful guidance on how to populate the given field. + // + // ```proto + // + // message MyFixed32 { + // fixed32 value = 1 [ + // (buf.validate.field).fixed32.example = 1, + // (buf.validate.field).fixed32.example = 2 + // ]; + // } + // + // ``` + Example []uint32 `protobuf:"fixed32,8,rep,name=example" json:"example,omitempty"` + extensionFields protoimpl.ExtensionFields + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Fixed32Rules) Reset() { + *x = Fixed32Rules{} + mi := &file_buf_validate_validate_proto_msgTypes[14] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Fixed32Rules) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Fixed32Rules) ProtoMessage() {} + +func (x *Fixed32Rules) ProtoReflect() protoreflect.Message { + mi := &file_buf_validate_validate_proto_msgTypes[14] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use 
Fixed32Rules.ProtoReflect.Descriptor instead. +func (*Fixed32Rules) Descriptor() ([]byte, []int) { + return file_buf_validate_validate_proto_rawDescGZIP(), []int{14} +} + +func (x *Fixed32Rules) GetConst() uint32 { + if x != nil && x.Const != nil { + return *x.Const + } + return 0 +} + +func (x *Fixed32Rules) GetLessThan() isFixed32Rules_LessThan { + if x != nil { + return x.LessThan + } + return nil +} + +func (x *Fixed32Rules) GetLt() uint32 { + if x != nil { + if x, ok := x.LessThan.(*Fixed32Rules_Lt); ok { + return x.Lt + } + } + return 0 +} + +func (x *Fixed32Rules) GetLte() uint32 { + if x != nil { + if x, ok := x.LessThan.(*Fixed32Rules_Lte); ok { + return x.Lte + } + } + return 0 +} + +func (x *Fixed32Rules) GetGreaterThan() isFixed32Rules_GreaterThan { + if x != nil { + return x.GreaterThan + } + return nil +} + +func (x *Fixed32Rules) GetGt() uint32 { + if x != nil { + if x, ok := x.GreaterThan.(*Fixed32Rules_Gt); ok { + return x.Gt + } + } + return 0 +} + +func (x *Fixed32Rules) GetGte() uint32 { + if x != nil { + if x, ok := x.GreaterThan.(*Fixed32Rules_Gte); ok { + return x.Gte + } + } + return 0 +} + +func (x *Fixed32Rules) GetIn() []uint32 { + if x != nil { + return x.In + } + return nil +} + +func (x *Fixed32Rules) GetNotIn() []uint32 { + if x != nil { + return x.NotIn + } + return nil +} + +func (x *Fixed32Rules) GetExample() []uint32 { + if x != nil { + return x.Example + } + return nil +} + +type isFixed32Rules_LessThan interface { + isFixed32Rules_LessThan() +} + +type Fixed32Rules_Lt struct { + // `lt` requires the field value to be less than the specified value (field < + // value). If the field value is equal to or greater than the specified value, + // an error message is generated. 
+ // + // ```proto + // + // message MyFixed32 { + // // must be less than 10 + // fixed32 value = 1 [(buf.validate.field).fixed32.lt = 10]; + // } + // + // ``` + Lt uint32 `protobuf:"fixed32,2,opt,name=lt,oneof"` +} + +type Fixed32Rules_Lte struct { + // `lte` requires the field value to be less than or equal to the specified + // value (field <= value). If the field value is greater than the specified + // value, an error message is generated. + // + // ```proto + // + // message MyFixed32 { + // // must be less than or equal to 10 + // fixed32 value = 1 [(buf.validate.field).fixed32.lte = 10]; + // } + // + // ``` + Lte uint32 `protobuf:"fixed32,3,opt,name=lte,oneof"` +} + +func (*Fixed32Rules_Lt) isFixed32Rules_LessThan() {} + +func (*Fixed32Rules_Lte) isFixed32Rules_LessThan() {} + +type isFixed32Rules_GreaterThan interface { + isFixed32Rules_GreaterThan() +} + +type Fixed32Rules_Gt struct { + // `gt` requires the field value to be greater than the specified value + // (exclusive). If the value of `gt` is larger than a specified `lt` or + // `lte`, the range is reversed, and the field value must be outside the + // specified range. If the field value doesn't meet the required conditions, + // an error message is generated. + // + // ```proto + // + // message MyFixed32 { + // // must be greater than 5 [fixed32.gt] + // fixed32 value = 1 [(buf.validate.field).fixed32.gt = 5]; + // + // // must be greater than 5 and less than 10 [fixed32.gt_lt] + // fixed32 other_value = 2 [(buf.validate.field).fixed32 = { gt: 5, lt: 10 }]; + // + // // must be greater than 10 or less than 5 [fixed32.gt_lt_exclusive] + // fixed32 another_value = 3 [(buf.validate.field).fixed32 = { gt: 10, lt: 5 }]; + // } + // + // ``` + Gt uint32 `protobuf:"fixed32,4,opt,name=gt,oneof"` +} + +type Fixed32Rules_Gte struct { + // `gte` requires the field value to be greater than or equal to the specified + // value (exclusive). 
If the value of `gte` is larger than a specified `lt` + // or `lte`, the range is reversed, and the field value must be outside the + // specified range. If the field value doesn't meet the required conditions, + // an error message is generated. + // + // ```proto + // + // message MyFixed32 { + // // must be greater than or equal to 5 [fixed32.gte] + // fixed32 value = 1 [(buf.validate.field).fixed32.gte = 5]; + // + // // must be greater than or equal to 5 and less than 10 [fixed32.gte_lt] + // fixed32 other_value = 2 [(buf.validate.field).fixed32 = { gte: 5, lt: 10 }]; + // + // // must be greater than or equal to 10 or less than 5 [fixed32.gte_lt_exclusive] + // fixed32 another_value = 3 [(buf.validate.field).fixed32 = { gte: 10, lt: 5 }]; + // } + // + // ``` + Gte uint32 `protobuf:"fixed32,5,opt,name=gte,oneof"` +} + +func (*Fixed32Rules_Gt) isFixed32Rules_GreaterThan() {} + +func (*Fixed32Rules_Gte) isFixed32Rules_GreaterThan() {} + +// Fixed64Rules describes the rules applied to `fixed64` values. +type Fixed64Rules struct { + state protoimpl.MessageState `protogen:"open.v1"` + // `const` requires the field value to exactly match the specified value. If + // the field value doesn't match, an error message is generated. + // + // ```proto + // + // message MyFixed64 { + // // value must equal 42 + // fixed64 value = 1 [(buf.validate.field).fixed64.const = 42]; + // } + // + // ``` + Const *uint64 `protobuf:"fixed64,1,opt,name=const" json:"const,omitempty"` + // Types that are valid to be assigned to LessThan: + // + // *Fixed64Rules_Lt + // *Fixed64Rules_Lte + LessThan isFixed64Rules_LessThan `protobuf_oneof:"less_than"` + // Types that are valid to be assigned to GreaterThan: + // + // *Fixed64Rules_Gt + // *Fixed64Rules_Gte + GreaterThan isFixed64Rules_GreaterThan `protobuf_oneof:"greater_than"` + // `in` requires the field value to be equal to one of the specified values. 
+ // If the field value isn't one of the specified values, an error message is + // generated. + // + // ```proto + // + // message MyFixed64 { + // // must be in list [1, 2, 3] + // fixed64 value = 1 [(buf.validate.field).fixed64 = { in: [1, 2, 3] }]; + // } + // + // ``` + In []uint64 `protobuf:"fixed64,6,rep,name=in" json:"in,omitempty"` + // `not_in` requires the field value to not be equal to any of the specified + // values. If the field value is one of the specified values, an error + // message is generated. + // + // ```proto + // + // message MyFixed64 { + // // value must not be in list [1, 2, 3] + // fixed64 value = 1 [(buf.validate.field).fixed64 = { not_in: [1, 2, 3] }]; + // } + // + // ``` + NotIn []uint64 `protobuf:"fixed64,7,rep,name=not_in,json=notIn" json:"not_in,omitempty"` + // `example` specifies values that the field may have. These values SHOULD + // conform to other rules. `example` values will not impact validation + // but may be used as helpful guidance on how to populate the given field. 
+ // + // ```proto + // + // message MyFixed64 { + // fixed64 value = 1 [ + // (buf.validate.field).fixed64.example = 1, + // (buf.validate.field).fixed64.example = 2 + // ]; + // } + // + // ``` + Example []uint64 `protobuf:"fixed64,8,rep,name=example" json:"example,omitempty"` + extensionFields protoimpl.ExtensionFields + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Fixed64Rules) Reset() { + *x = Fixed64Rules{} + mi := &file_buf_validate_validate_proto_msgTypes[15] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Fixed64Rules) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Fixed64Rules) ProtoMessage() {} + +func (x *Fixed64Rules) ProtoReflect() protoreflect.Message { + mi := &file_buf_validate_validate_proto_msgTypes[15] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Fixed64Rules.ProtoReflect.Descriptor instead. 
+func (*Fixed64Rules) Descriptor() ([]byte, []int) { + return file_buf_validate_validate_proto_rawDescGZIP(), []int{15} +} + +func (x *Fixed64Rules) GetConst() uint64 { + if x != nil && x.Const != nil { + return *x.Const + } + return 0 +} + +func (x *Fixed64Rules) GetLessThan() isFixed64Rules_LessThan { + if x != nil { + return x.LessThan + } + return nil +} + +func (x *Fixed64Rules) GetLt() uint64 { + if x != nil { + if x, ok := x.LessThan.(*Fixed64Rules_Lt); ok { + return x.Lt + } + } + return 0 +} + +func (x *Fixed64Rules) GetLte() uint64 { + if x != nil { + if x, ok := x.LessThan.(*Fixed64Rules_Lte); ok { + return x.Lte + } + } + return 0 +} + +func (x *Fixed64Rules) GetGreaterThan() isFixed64Rules_GreaterThan { + if x != nil { + return x.GreaterThan + } + return nil +} + +func (x *Fixed64Rules) GetGt() uint64 { + if x != nil { + if x, ok := x.GreaterThan.(*Fixed64Rules_Gt); ok { + return x.Gt + } + } + return 0 +} + +func (x *Fixed64Rules) GetGte() uint64 { + if x != nil { + if x, ok := x.GreaterThan.(*Fixed64Rules_Gte); ok { + return x.Gte + } + } + return 0 +} + +func (x *Fixed64Rules) GetIn() []uint64 { + if x != nil { + return x.In + } + return nil +} + +func (x *Fixed64Rules) GetNotIn() []uint64 { + if x != nil { + return x.NotIn + } + return nil +} + +func (x *Fixed64Rules) GetExample() []uint64 { + if x != nil { + return x.Example + } + return nil +} + +type isFixed64Rules_LessThan interface { + isFixed64Rules_LessThan() +} + +type Fixed64Rules_Lt struct { + // `lt` requires the field value to be less than the specified value (field < + // value). If the field value is equal to or greater than the specified value, + // an error message is generated. 
+ // + // ```proto + // + // message MyFixed64 { + // // must be less than 10 + // fixed64 value = 1 [(buf.validate.field).fixed64.lt = 10]; + // } + // + // ``` + Lt uint64 `protobuf:"fixed64,2,opt,name=lt,oneof"` +} + +type Fixed64Rules_Lte struct { + // `lte` requires the field value to be less than or equal to the specified + // value (field <= value). If the field value is greater than the specified + // value, an error message is generated. + // + // ```proto + // + // message MyFixed64 { + // // must be less than or equal to 10 + // fixed64 value = 1 [(buf.validate.field).fixed64.lte = 10]; + // } + // + // ``` + Lte uint64 `protobuf:"fixed64,3,opt,name=lte,oneof"` +} + +func (*Fixed64Rules_Lt) isFixed64Rules_LessThan() {} + +func (*Fixed64Rules_Lte) isFixed64Rules_LessThan() {} + +type isFixed64Rules_GreaterThan interface { + isFixed64Rules_GreaterThan() +} + +type Fixed64Rules_Gt struct { + // `gt` requires the field value to be greater than the specified value + // (exclusive). If the value of `gt` is larger than a specified `lt` or + // `lte`, the range is reversed, and the field value must be outside the + // specified range. If the field value doesn't meet the required conditions, + // an error message is generated. + // + // ```proto + // + // message MyFixed64 { + // // must be greater than 5 [fixed64.gt] + // fixed64 value = 1 [(buf.validate.field).fixed64.gt = 5]; + // + // // must be greater than 5 and less than 10 [fixed64.gt_lt] + // fixed64 other_value = 2 [(buf.validate.field).fixed64 = { gt: 5, lt: 10 }]; + // + // // must be greater than 10 or less than 5 [fixed64.gt_lt_exclusive] + // fixed64 another_value = 3 [(buf.validate.field).fixed64 = { gt: 10, lt: 5 }]; + // } + // + // ``` + Gt uint64 `protobuf:"fixed64,4,opt,name=gt,oneof"` +} + +type Fixed64Rules_Gte struct { + // `gte` requires the field value to be greater than or equal to the specified + // value (exclusive). 
If the value of `gte` is larger than a specified `lt` + // or `lte`, the range is reversed, and the field value must be outside the + // specified range. If the field value doesn't meet the required conditions, + // an error message is generated. + // + // ```proto + // + // message MyFixed64 { + // // must be greater than or equal to 5 [fixed64.gte] + // fixed64 value = 1 [(buf.validate.field).fixed64.gte = 5]; + // + // // must be greater than or equal to 5 and less than 10 [fixed64.gte_lt] + // fixed64 other_value = 2 [(buf.validate.field).fixed64 = { gte: 5, lt: 10 }]; + // + // // must be greater than or equal to 10 or less than 5 [fixed64.gte_lt_exclusive] + // fixed64 another_value = 3 [(buf.validate.field).fixed64 = { gte: 10, lt: 5 }]; + // } + // + // ``` + Gte uint64 `protobuf:"fixed64,5,opt,name=gte,oneof"` +} + +func (*Fixed64Rules_Gt) isFixed64Rules_GreaterThan() {} + +func (*Fixed64Rules_Gte) isFixed64Rules_GreaterThan() {} + +// SFixed32Rules describes the rules applied to `fixed32` values. +type SFixed32Rules struct { + state protoimpl.MessageState `protogen:"open.v1"` + // `const` requires the field value to exactly match the specified value. If + // the field value doesn't match, an error message is generated. + // + // ```proto + // + // message MySFixed32 { + // // value must equal 42 + // sfixed32 value = 1 [(buf.validate.field).sfixed32.const = 42]; + // } + // + // ``` + Const *int32 `protobuf:"fixed32,1,opt,name=const" json:"const,omitempty"` + // Types that are valid to be assigned to LessThan: + // + // *SFixed32Rules_Lt + // *SFixed32Rules_Lte + LessThan isSFixed32Rules_LessThan `protobuf_oneof:"less_than"` + // Types that are valid to be assigned to GreaterThan: + // + // *SFixed32Rules_Gt + // *SFixed32Rules_Gte + GreaterThan isSFixed32Rules_GreaterThan `protobuf_oneof:"greater_than"` + // `in` requires the field value to be equal to one of the specified values. 
+ // If the field value isn't one of the specified values, an error message is + // generated. + // + // ```proto + // + // message MySFixed32 { + // // must be in list [1, 2, 3] + // sfixed32 value = 1 [(buf.validate.field).sfixed32 = { in: [1, 2, 3] }]; + // } + // + // ``` + In []int32 `protobuf:"fixed32,6,rep,name=in" json:"in,omitempty"` + // `not_in` requires the field value to not be equal to any of the specified + // values. If the field value is one of the specified values, an error + // message is generated. + // + // ```proto + // + // message MySFixed32 { + // // value must not be in list [1, 2, 3] + // sfixed32 value = 1 [(buf.validate.field).sfixed32 = { not_in: [1, 2, 3] }]; + // } + // + // ``` + NotIn []int32 `protobuf:"fixed32,7,rep,name=not_in,json=notIn" json:"not_in,omitempty"` + // `example` specifies values that the field may have. These values SHOULD + // conform to other rules. `example` values will not impact validation + // but may be used as helpful guidance on how to populate the given field. 
+ // + // ```proto + // + // message MySFixed32 { + // sfixed32 value = 1 [ + // (buf.validate.field).sfixed32.example = 1, + // (buf.validate.field).sfixed32.example = 2 + // ]; + // } + // + // ``` + Example []int32 `protobuf:"fixed32,8,rep,name=example" json:"example,omitempty"` + extensionFields protoimpl.ExtensionFields + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *SFixed32Rules) Reset() { + *x = SFixed32Rules{} + mi := &file_buf_validate_validate_proto_msgTypes[16] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SFixed32Rules) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SFixed32Rules) ProtoMessage() {} + +func (x *SFixed32Rules) ProtoReflect() protoreflect.Message { + mi := &file_buf_validate_validate_proto_msgTypes[16] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SFixed32Rules.ProtoReflect.Descriptor instead. 
+func (*SFixed32Rules) Descriptor() ([]byte, []int) { + return file_buf_validate_validate_proto_rawDescGZIP(), []int{16} +} + +func (x *SFixed32Rules) GetConst() int32 { + if x != nil && x.Const != nil { + return *x.Const + } + return 0 +} + +func (x *SFixed32Rules) GetLessThan() isSFixed32Rules_LessThan { + if x != nil { + return x.LessThan + } + return nil +} + +func (x *SFixed32Rules) GetLt() int32 { + if x != nil { + if x, ok := x.LessThan.(*SFixed32Rules_Lt); ok { + return x.Lt + } + } + return 0 +} + +func (x *SFixed32Rules) GetLte() int32 { + if x != nil { + if x, ok := x.LessThan.(*SFixed32Rules_Lte); ok { + return x.Lte + } + } + return 0 +} + +func (x *SFixed32Rules) GetGreaterThan() isSFixed32Rules_GreaterThan { + if x != nil { + return x.GreaterThan + } + return nil +} + +func (x *SFixed32Rules) GetGt() int32 { + if x != nil { + if x, ok := x.GreaterThan.(*SFixed32Rules_Gt); ok { + return x.Gt + } + } + return 0 +} + +func (x *SFixed32Rules) GetGte() int32 { + if x != nil { + if x, ok := x.GreaterThan.(*SFixed32Rules_Gte); ok { + return x.Gte + } + } + return 0 +} + +func (x *SFixed32Rules) GetIn() []int32 { + if x != nil { + return x.In + } + return nil +} + +func (x *SFixed32Rules) GetNotIn() []int32 { + if x != nil { + return x.NotIn + } + return nil +} + +func (x *SFixed32Rules) GetExample() []int32 { + if x != nil { + return x.Example + } + return nil +} + +type isSFixed32Rules_LessThan interface { + isSFixed32Rules_LessThan() +} + +type SFixed32Rules_Lt struct { + // `lt` requires the field value to be less than the specified value (field < + // value). If the field value is equal to or greater than the specified value, + // an error message is generated. 
+ // + // ```proto + // + // message MySFixed32 { + // // must be less than 10 + // sfixed32 value = 1 [(buf.validate.field).sfixed32.lt = 10]; + // } + // + // ``` + Lt int32 `protobuf:"fixed32,2,opt,name=lt,oneof"` +} + +type SFixed32Rules_Lte struct { + // `lte` requires the field value to be less than or equal to the specified + // value (field <= value). If the field value is greater than the specified + // value, an error message is generated. + // + // ```proto + // + // message MySFixed32 { + // // must be less than or equal to 10 + // sfixed32 value = 1 [(buf.validate.field).sfixed32.lte = 10]; + // } + // + // ``` + Lte int32 `protobuf:"fixed32,3,opt,name=lte,oneof"` +} + +func (*SFixed32Rules_Lt) isSFixed32Rules_LessThan() {} + +func (*SFixed32Rules_Lte) isSFixed32Rules_LessThan() {} + +type isSFixed32Rules_GreaterThan interface { + isSFixed32Rules_GreaterThan() +} + +type SFixed32Rules_Gt struct { + // `gt` requires the field value to be greater than the specified value + // (exclusive). If the value of `gt` is larger than a specified `lt` or + // `lte`, the range is reversed, and the field value must be outside the + // specified range. If the field value doesn't meet the required conditions, + // an error message is generated. + // + // ```proto + // + // message MySFixed32 { + // // must be greater than 5 [sfixed32.gt] + // sfixed32 value = 1 [(buf.validate.field).sfixed32.gt = 5]; + // + // // must be greater than 5 and less than 10 [sfixed32.gt_lt] + // sfixed32 other_value = 2 [(buf.validate.field).sfixed32 = { gt: 5, lt: 10 }]; + // + // // must be greater than 10 or less than 5 [sfixed32.gt_lt_exclusive] + // sfixed32 another_value = 3 [(buf.validate.field).sfixed32 = { gt: 10, lt: 5 }]; + // } + // + // ``` + Gt int32 `protobuf:"fixed32,4,opt,name=gt,oneof"` +} + +type SFixed32Rules_Gte struct { + // `gte` requires the field value to be greater than or equal to the specified + // value (exclusive). 
If the value of `gte` is larger than a specified `lt` + // or `lte`, the range is reversed, and the field value must be outside the + // specified range. If the field value doesn't meet the required conditions, + // an error message is generated. + // + // ```proto + // + // message MySFixed32 { + // // must be greater than or equal to 5 [sfixed32.gte] + // sfixed32 value = 1 [(buf.validate.field).sfixed32.gte = 5]; + // + // // must be greater than or equal to 5 and less than 10 [sfixed32.gte_lt] + // sfixed32 other_value = 2 [(buf.validate.field).sfixed32 = { gte: 5, lt: 10 }]; + // + // // must be greater than or equal to 10 or less than 5 [sfixed32.gte_lt_exclusive] + // sfixed32 another_value = 3 [(buf.validate.field).sfixed32 = { gte: 10, lt: 5 }]; + // } + // + // ``` + Gte int32 `protobuf:"fixed32,5,opt,name=gte,oneof"` +} + +func (*SFixed32Rules_Gt) isSFixed32Rules_GreaterThan() {} + +func (*SFixed32Rules_Gte) isSFixed32Rules_GreaterThan() {} + +// SFixed64Rules describes the rules applied to `fixed64` values. +type SFixed64Rules struct { + state protoimpl.MessageState `protogen:"open.v1"` + // `const` requires the field value to exactly match the specified value. If + // the field value doesn't match, an error message is generated. + // + // ```proto + // + // message MySFixed64 { + // // value must equal 42 + // sfixed64 value = 1 [(buf.validate.field).sfixed64.const = 42]; + // } + // + // ``` + Const *int64 `protobuf:"fixed64,1,opt,name=const" json:"const,omitempty"` + // Types that are valid to be assigned to LessThan: + // + // *SFixed64Rules_Lt + // *SFixed64Rules_Lte + LessThan isSFixed64Rules_LessThan `protobuf_oneof:"less_than"` + // Types that are valid to be assigned to GreaterThan: + // + // *SFixed64Rules_Gt + // *SFixed64Rules_Gte + GreaterThan isSFixed64Rules_GreaterThan `protobuf_oneof:"greater_than"` + // `in` requires the field value to be equal to one of the specified values. 
+ // If the field value isn't one of the specified values, an error message is + // generated. + // + // ```proto + // + // message MySFixed64 { + // // must be in list [1, 2, 3] + // sfixed64 value = 1 [(buf.validate.field).sfixed64 = { in: [1, 2, 3] }]; + // } + // + // ``` + In []int64 `protobuf:"fixed64,6,rep,name=in" json:"in,omitempty"` + // `not_in` requires the field value to not be equal to any of the specified + // values. If the field value is one of the specified values, an error + // message is generated. + // + // ```proto + // + // message MySFixed64 { + // // value must not be in list [1, 2, 3] + // sfixed64 value = 1 [(buf.validate.field).sfixed64 = { not_in: [1, 2, 3] }]; + // } + // + // ``` + NotIn []int64 `protobuf:"fixed64,7,rep,name=not_in,json=notIn" json:"not_in,omitempty"` + // `example` specifies values that the field may have. These values SHOULD + // conform to other rules. `example` values will not impact validation + // but may be used as helpful guidance on how to populate the given field. 
+ // + // ```proto + // + // message MySFixed64 { + // sfixed64 value = 1 [ + // (buf.validate.field).sfixed64.example = 1, + // (buf.validate.field).sfixed64.example = 2 + // ]; + // } + // + // ``` + Example []int64 `protobuf:"fixed64,8,rep,name=example" json:"example,omitempty"` + extensionFields protoimpl.ExtensionFields + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *SFixed64Rules) Reset() { + *x = SFixed64Rules{} + mi := &file_buf_validate_validate_proto_msgTypes[17] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SFixed64Rules) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SFixed64Rules) ProtoMessage() {} + +func (x *SFixed64Rules) ProtoReflect() protoreflect.Message { + mi := &file_buf_validate_validate_proto_msgTypes[17] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SFixed64Rules.ProtoReflect.Descriptor instead. 
+func (*SFixed64Rules) Descriptor() ([]byte, []int) { + return file_buf_validate_validate_proto_rawDescGZIP(), []int{17} +} + +func (x *SFixed64Rules) GetConst() int64 { + if x != nil && x.Const != nil { + return *x.Const + } + return 0 +} + +func (x *SFixed64Rules) GetLessThan() isSFixed64Rules_LessThan { + if x != nil { + return x.LessThan + } + return nil +} + +func (x *SFixed64Rules) GetLt() int64 { + if x != nil { + if x, ok := x.LessThan.(*SFixed64Rules_Lt); ok { + return x.Lt + } + } + return 0 +} + +func (x *SFixed64Rules) GetLte() int64 { + if x != nil { + if x, ok := x.LessThan.(*SFixed64Rules_Lte); ok { + return x.Lte + } + } + return 0 +} + +func (x *SFixed64Rules) GetGreaterThan() isSFixed64Rules_GreaterThan { + if x != nil { + return x.GreaterThan + } + return nil +} + +func (x *SFixed64Rules) GetGt() int64 { + if x != nil { + if x, ok := x.GreaterThan.(*SFixed64Rules_Gt); ok { + return x.Gt + } + } + return 0 +} + +func (x *SFixed64Rules) GetGte() int64 { + if x != nil { + if x, ok := x.GreaterThan.(*SFixed64Rules_Gte); ok { + return x.Gte + } + } + return 0 +} + +func (x *SFixed64Rules) GetIn() []int64 { + if x != nil { + return x.In + } + return nil +} + +func (x *SFixed64Rules) GetNotIn() []int64 { + if x != nil { + return x.NotIn + } + return nil +} + +func (x *SFixed64Rules) GetExample() []int64 { + if x != nil { + return x.Example + } + return nil +} + +type isSFixed64Rules_LessThan interface { + isSFixed64Rules_LessThan() +} + +type SFixed64Rules_Lt struct { + // `lt` requires the field value to be less than the specified value (field < + // value). If the field value is equal to or greater than the specified value, + // an error message is generated. 
+ // + // ```proto + // + // message MySFixed64 { + // // must be less than 10 + // sfixed64 value = 1 [(buf.validate.field).sfixed64.lt = 10]; + // } + // + // ``` + Lt int64 `protobuf:"fixed64,2,opt,name=lt,oneof"` +} + +type SFixed64Rules_Lte struct { + // `lte` requires the field value to be less than or equal to the specified + // value (field <= value). If the field value is greater than the specified + // value, an error message is generated. + // + // ```proto + // + // message MySFixed64 { + // // must be less than or equal to 10 + // sfixed64 value = 1 [(buf.validate.field).sfixed64.lte = 10]; + // } + // + // ``` + Lte int64 `protobuf:"fixed64,3,opt,name=lte,oneof"` +} + +func (*SFixed64Rules_Lt) isSFixed64Rules_LessThan() {} + +func (*SFixed64Rules_Lte) isSFixed64Rules_LessThan() {} + +type isSFixed64Rules_GreaterThan interface { + isSFixed64Rules_GreaterThan() +} + +type SFixed64Rules_Gt struct { + // `gt` requires the field value to be greater than the specified value + // (exclusive). If the value of `gt` is larger than a specified `lt` or + // `lte`, the range is reversed, and the field value must be outside the + // specified range. If the field value doesn't meet the required conditions, + // an error message is generated. + // + // ```proto + // + // message MySFixed64 { + // // must be greater than 5 [sfixed64.gt] + // sfixed64 value = 1 [(buf.validate.field).sfixed64.gt = 5]; + // + // // must be greater than 5 and less than 10 [sfixed64.gt_lt] + // sfixed64 other_value = 2 [(buf.validate.field).sfixed64 = { gt: 5, lt: 10 }]; + // + // // must be greater than 10 or less than 5 [sfixed64.gt_lt_exclusive] + // sfixed64 another_value = 3 [(buf.validate.field).sfixed64 = { gt: 10, lt: 5 }]; + // } + // + // ``` + Gt int64 `protobuf:"fixed64,4,opt,name=gt,oneof"` +} + +type SFixed64Rules_Gte struct { + // `gte` requires the field value to be greater than or equal to the specified + // value (exclusive). 
If the value of `gte` is larger than a specified `lt` + // or `lte`, the range is reversed, and the field value must be outside the + // specified range. If the field value doesn't meet the required conditions, + // an error message is generated. + // + // ```proto + // + // message MySFixed64 { + // // must be greater than or equal to 5 [sfixed64.gte] + // sfixed64 value = 1 [(buf.validate.field).sfixed64.gte = 5]; + // + // // must be greater than or equal to 5 and less than 10 [sfixed64.gte_lt] + // sfixed64 other_value = 2 [(buf.validate.field).sfixed64 = { gte: 5, lt: 10 }]; + // + // // must be greater than or equal to 10 or less than 5 [sfixed64.gte_lt_exclusive] + // sfixed64 another_value = 3 [(buf.validate.field).sfixed64 = { gte: 10, lt: 5 }]; + // } + // + // ``` + Gte int64 `protobuf:"fixed64,5,opt,name=gte,oneof"` +} + +func (*SFixed64Rules_Gt) isSFixed64Rules_GreaterThan() {} + +func (*SFixed64Rules_Gte) isSFixed64Rules_GreaterThan() {} + +// BoolRules describes the rules applied to `bool` values. These rules +// may also be applied to the `google.protobuf.BoolValue` Well-Known-Type. +type BoolRules struct { + state protoimpl.MessageState `protogen:"open.v1"` + // `const` requires the field value to exactly match the specified boolean value. + // If the field value doesn't match, an error message is generated. + // + // ```proto + // + // message MyBool { + // // value must equal true + // bool value = 1 [(buf.validate.field).bool.const = true]; + // } + // + // ``` + Const *bool `protobuf:"varint,1,opt,name=const" json:"const,omitempty"` + // `example` specifies values that the field may have. These values SHOULD + // conform to other rules. `example` values will not impact validation + // but may be used as helpful guidance on how to populate the given field. 
+ // + // ```proto + // + // message MyBool { + // bool value = 1 [ + // (buf.validate.field).bool.example = 1, + // (buf.validate.field).bool.example = 2 + // ]; + // } + // + // ``` + Example []bool `protobuf:"varint,2,rep,name=example" json:"example,omitempty"` + extensionFields protoimpl.ExtensionFields + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *BoolRules) Reset() { + *x = BoolRules{} + mi := &file_buf_validate_validate_proto_msgTypes[18] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *BoolRules) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*BoolRules) ProtoMessage() {} + +func (x *BoolRules) ProtoReflect() protoreflect.Message { + mi := &file_buf_validate_validate_proto_msgTypes[18] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use BoolRules.ProtoReflect.Descriptor instead. +func (*BoolRules) Descriptor() ([]byte, []int) { + return file_buf_validate_validate_proto_rawDescGZIP(), []int{18} +} + +func (x *BoolRules) GetConst() bool { + if x != nil && x.Const != nil { + return *x.Const + } + return false +} + +func (x *BoolRules) GetExample() []bool { + if x != nil { + return x.Example + } + return nil +} + +// StringRules describes the rules applied to `string` values These +// rules may also be applied to the `google.protobuf.StringValue` Well-Known-Type. +type StringRules struct { + state protoimpl.MessageState `protogen:"open.v1"` + // `const` requires the field value to exactly match the specified value. If + // the field value doesn't match, an error message is generated. 
+ // + // ```proto + // + // message MyString { + // // value must equal `hello` + // string value = 1 [(buf.validate.field).string.const = "hello"]; + // } + // + // ``` + Const *string `protobuf:"bytes,1,opt,name=const" json:"const,omitempty"` + // `len` dictates that the field value must have the specified + // number of characters (Unicode code points), which may differ from the number + // of bytes in the string. If the field value does not meet the specified + // length, an error message will be generated. + // + // ```proto + // + // message MyString { + // // value length must be 5 characters + // string value = 1 [(buf.validate.field).string.len = 5]; + // } + // + // ``` + Len *uint64 `protobuf:"varint,19,opt,name=len" json:"len,omitempty"` + // `min_len` specifies that the field value must have at least the specified + // number of characters (Unicode code points), which may differ from the number + // of bytes in the string. If the field value contains fewer characters, an error + // message will be generated. + // + // ```proto + // + // message MyString { + // // value length must be at least 3 characters + // string value = 1 [(buf.validate.field).string.min_len = 3]; + // } + // + // ``` + MinLen *uint64 `protobuf:"varint,2,opt,name=min_len,json=minLen" json:"min_len,omitempty"` + // `max_len` specifies that the field value must have no more than the specified + // number of characters (Unicode code points), which may differ from the + // number of bytes in the string. If the field value contains more characters, + // an error message will be generated. + // + // ```proto + // + // message MyString { + // // value length must be at most 10 characters + // string value = 1 [(buf.validate.field).string.max_len = 10]; + // } + // + // ``` + MaxLen *uint64 `protobuf:"varint,3,opt,name=max_len,json=maxLen" json:"max_len,omitempty"` + // `len_bytes` dictates that the field value must have the specified number of + // bytes. 
If the field value does not match the specified length in bytes, + // an error message will be generated. + // + // ```proto + // + // message MyString { + // // value length must be 6 bytes + // string value = 1 [(buf.validate.field).string.len_bytes = 6]; + // } + // + // ``` + LenBytes *uint64 `protobuf:"varint,20,opt,name=len_bytes,json=lenBytes" json:"len_bytes,omitempty"` + // `min_bytes` specifies that the field value must have at least the specified + // number of bytes. If the field value contains fewer bytes, an error message + // will be generated. + // + // ```proto + // + // message MyString { + // // value length must be at least 4 bytes + // string value = 1 [(buf.validate.field).string.min_bytes = 4]; + // } + // + // ``` + MinBytes *uint64 `protobuf:"varint,4,opt,name=min_bytes,json=minBytes" json:"min_bytes,omitempty"` + // `max_bytes` specifies that the field value must have no more than the + // specified number of bytes. If the field value contains more bytes, an + // error message will be generated. + // + // ```proto + // + // message MyString { + // // value length must be at most 8 bytes + // string value = 1 [(buf.validate.field).string.max_bytes = 8]; + // } + // + // ``` + MaxBytes *uint64 `protobuf:"varint,5,opt,name=max_bytes,json=maxBytes" json:"max_bytes,omitempty"` + // `pattern` specifies that the field value must match the specified + // regular expression (RE2 syntax), with the expression provided without any + // delimiters. If the field value doesn't match the regular expression, an + // error message will be generated. + // + // ```proto + // + // message MyString { + // // value does not match regex pattern `^[a-zA-Z]//$` + // string value = 1 [(buf.validate.field).string.pattern = "^[a-zA-Z]//$"]; + // } + // + // ``` + Pattern *string `protobuf:"bytes,6,opt,name=pattern" json:"pattern,omitempty"` + // `prefix` specifies that the field value must have the + // specified substring at the beginning of the string. 
If the field value + // doesn't start with the specified prefix, an error message will be + // generated. + // + // ```proto + // + // message MyString { + // // value does not have prefix `pre` + // string value = 1 [(buf.validate.field).string.prefix = "pre"]; + // } + // + // ``` + Prefix *string `protobuf:"bytes,7,opt,name=prefix" json:"prefix,omitempty"` + // `suffix` specifies that the field value must have the + // specified substring at the end of the string. If the field value doesn't + // end with the specified suffix, an error message will be generated. + // + // ```proto + // + // message MyString { + // // value does not have suffix `post` + // string value = 1 [(buf.validate.field).string.suffix = "post"]; + // } + // + // ``` + Suffix *string `protobuf:"bytes,8,opt,name=suffix" json:"suffix,omitempty"` + // `contains` specifies that the field value must have the + // specified substring anywhere in the string. If the field value doesn't + // contain the specified substring, an error message will be generated. + // + // ```proto + // + // message MyString { + // // value does not contain substring `inside`. + // string value = 1 [(buf.validate.field).string.contains = "inside"]; + // } + // + // ``` + Contains *string `protobuf:"bytes,9,opt,name=contains" json:"contains,omitempty"` + // `not_contains` specifies that the field value must not have the + // specified substring anywhere in the string. If the field value contains + // the specified substring, an error message will be generated. + // + // ```proto + // + // message MyString { + // // value contains substring `inside`. + // string value = 1 [(buf.validate.field).string.not_contains = "inside"]; + // } + // + // ``` + NotContains *string `protobuf:"bytes,23,opt,name=not_contains,json=notContains" json:"not_contains,omitempty"` + // `in` specifies that the field value must be equal to one of the specified + // values. 
If the field value isn't one of the specified values, an error + // message will be generated. + // + // ```proto + // + // message MyString { + // // must be in list ["apple", "banana"] + // string value = 1 [(buf.validate.field).string.in = "apple", (buf.validate.field).string.in = "banana"]; + // } + // + // ``` + In []string `protobuf:"bytes,10,rep,name=in" json:"in,omitempty"` + // `not_in` specifies that the field value cannot be equal to any + // of the specified values. If the field value is one of the specified values, + // an error message will be generated. + // ```proto + // + // message MyString { + // // value must not be in list ["orange", "grape"] + // string value = 1 [(buf.validate.field).string.not_in = "orange", (buf.validate.field).string.not_in = "grape"]; + // } + // + // ``` + NotIn []string `protobuf:"bytes,11,rep,name=not_in,json=notIn" json:"not_in,omitempty"` + // `WellKnown` rules provide advanced rules against common string + // patterns. + // + // Types that are valid to be assigned to WellKnown: + // + // *StringRules_Email + // *StringRules_Hostname + // *StringRules_Ip + // *StringRules_Ipv4 + // *StringRules_Ipv6 + // *StringRules_Uri + // *StringRules_UriRef + // *StringRules_Address + // *StringRules_Uuid + // *StringRules_Tuuid + // *StringRules_IpWithPrefixlen + // *StringRules_Ipv4WithPrefixlen + // *StringRules_Ipv6WithPrefixlen + // *StringRules_IpPrefix + // *StringRules_Ipv4Prefix + // *StringRules_Ipv6Prefix + // *StringRules_HostAndPort + // *StringRules_Ulid + // *StringRules_ProtobufFqn + // *StringRules_ProtobufDotFqn + // *StringRules_WellKnownRegex + WellKnown isStringRules_WellKnown `protobuf_oneof:"well_known"` + // This applies to regexes `HTTP_HEADER_NAME` and `HTTP_HEADER_VALUE` to + // enable strict header validation. By default, this is true, and HTTP header + // validations are [RFC-compliant](https://datatracker.ietf.org/doc/html/rfc7230#section-3). 
Setting to false will enable looser + // validations that only disallow `\r\n\0` characters, which can be used to + // bypass header matching rules. + // + // ```proto + // + // message MyString { + // // The field `value` must have be a valid HTTP headers, but not enforced with strict rules. + // string value = 1 [(buf.validate.field).string.strict = false]; + // } + // + // ``` + Strict *bool `protobuf:"varint,25,opt,name=strict" json:"strict,omitempty"` + // `example` specifies values that the field may have. These values SHOULD + // conform to other rules. `example` values will not impact validation + // but may be used as helpful guidance on how to populate the given field. + // + // ```proto + // + // message MyString { + // string value = 1 [ + // (buf.validate.field).string.example = "hello", + // (buf.validate.field).string.example = "world" + // ]; + // } + // + // ``` + Example []string `protobuf:"bytes,34,rep,name=example" json:"example,omitempty"` + extensionFields protoimpl.ExtensionFields + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *StringRules) Reset() { + *x = StringRules{} + mi := &file_buf_validate_validate_proto_msgTypes[19] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *StringRules) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StringRules) ProtoMessage() {} + +func (x *StringRules) ProtoReflect() protoreflect.Message { + mi := &file_buf_validate_validate_proto_msgTypes[19] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StringRules.ProtoReflect.Descriptor instead. 
+func (*StringRules) Descriptor() ([]byte, []int) { + return file_buf_validate_validate_proto_rawDescGZIP(), []int{19} +} + +func (x *StringRules) GetConst() string { + if x != nil && x.Const != nil { + return *x.Const + } + return "" +} + +func (x *StringRules) GetLen() uint64 { + if x != nil && x.Len != nil { + return *x.Len + } + return 0 +} + +func (x *StringRules) GetMinLen() uint64 { + if x != nil && x.MinLen != nil { + return *x.MinLen + } + return 0 +} + +func (x *StringRules) GetMaxLen() uint64 { + if x != nil && x.MaxLen != nil { + return *x.MaxLen + } + return 0 +} + +func (x *StringRules) GetLenBytes() uint64 { + if x != nil && x.LenBytes != nil { + return *x.LenBytes + } + return 0 +} + +func (x *StringRules) GetMinBytes() uint64 { + if x != nil && x.MinBytes != nil { + return *x.MinBytes + } + return 0 +} + +func (x *StringRules) GetMaxBytes() uint64 { + if x != nil && x.MaxBytes != nil { + return *x.MaxBytes + } + return 0 +} + +func (x *StringRules) GetPattern() string { + if x != nil && x.Pattern != nil { + return *x.Pattern + } + return "" +} + +func (x *StringRules) GetPrefix() string { + if x != nil && x.Prefix != nil { + return *x.Prefix + } + return "" +} + +func (x *StringRules) GetSuffix() string { + if x != nil && x.Suffix != nil { + return *x.Suffix + } + return "" +} + +func (x *StringRules) GetContains() string { + if x != nil && x.Contains != nil { + return *x.Contains + } + return "" +} + +func (x *StringRules) GetNotContains() string { + if x != nil && x.NotContains != nil { + return *x.NotContains + } + return "" +} + +func (x *StringRules) GetIn() []string { + if x != nil { + return x.In + } + return nil +} + +func (x *StringRules) GetNotIn() []string { + if x != nil { + return x.NotIn + } + return nil +} + +func (x *StringRules) GetWellKnown() isStringRules_WellKnown { + if x != nil { + return x.WellKnown + } + return nil +} + +func (x *StringRules) GetEmail() bool { + if x != nil { + if x, ok := x.WellKnown.(*StringRules_Email); 
ok { + return x.Email + } + } + return false +} + +func (x *StringRules) GetHostname() bool { + if x != nil { + if x, ok := x.WellKnown.(*StringRules_Hostname); ok { + return x.Hostname + } + } + return false +} + +func (x *StringRules) GetIp() bool { + if x != nil { + if x, ok := x.WellKnown.(*StringRules_Ip); ok { + return x.Ip + } + } + return false +} + +func (x *StringRules) GetIpv4() bool { + if x != nil { + if x, ok := x.WellKnown.(*StringRules_Ipv4); ok { + return x.Ipv4 + } + } + return false +} + +func (x *StringRules) GetIpv6() bool { + if x != nil { + if x, ok := x.WellKnown.(*StringRules_Ipv6); ok { + return x.Ipv6 + } + } + return false +} + +func (x *StringRules) GetUri() bool { + if x != nil { + if x, ok := x.WellKnown.(*StringRules_Uri); ok { + return x.Uri + } + } + return false +} + +func (x *StringRules) GetUriRef() bool { + if x != nil { + if x, ok := x.WellKnown.(*StringRules_UriRef); ok { + return x.UriRef + } + } + return false +} + +func (x *StringRules) GetAddress() bool { + if x != nil { + if x, ok := x.WellKnown.(*StringRules_Address); ok { + return x.Address + } + } + return false +} + +func (x *StringRules) GetUuid() bool { + if x != nil { + if x, ok := x.WellKnown.(*StringRules_Uuid); ok { + return x.Uuid + } + } + return false +} + +func (x *StringRules) GetTuuid() bool { + if x != nil { + if x, ok := x.WellKnown.(*StringRules_Tuuid); ok { + return x.Tuuid + } + } + return false +} + +func (x *StringRules) GetIpWithPrefixlen() bool { + if x != nil { + if x, ok := x.WellKnown.(*StringRules_IpWithPrefixlen); ok { + return x.IpWithPrefixlen + } + } + return false +} + +func (x *StringRules) GetIpv4WithPrefixlen() bool { + if x != nil { + if x, ok := x.WellKnown.(*StringRules_Ipv4WithPrefixlen); ok { + return x.Ipv4WithPrefixlen + } + } + return false +} + +func (x *StringRules) GetIpv6WithPrefixlen() bool { + if x != nil { + if x, ok := x.WellKnown.(*StringRules_Ipv6WithPrefixlen); ok { + return x.Ipv6WithPrefixlen + } + } + return 
false +} + +func (x *StringRules) GetIpPrefix() bool { + if x != nil { + if x, ok := x.WellKnown.(*StringRules_IpPrefix); ok { + return x.IpPrefix + } + } + return false +} + +func (x *StringRules) GetIpv4Prefix() bool { + if x != nil { + if x, ok := x.WellKnown.(*StringRules_Ipv4Prefix); ok { + return x.Ipv4Prefix + } + } + return false +} + +func (x *StringRules) GetIpv6Prefix() bool { + if x != nil { + if x, ok := x.WellKnown.(*StringRules_Ipv6Prefix); ok { + return x.Ipv6Prefix + } + } + return false +} + +func (x *StringRules) GetHostAndPort() bool { + if x != nil { + if x, ok := x.WellKnown.(*StringRules_HostAndPort); ok { + return x.HostAndPort + } + } + return false +} + +func (x *StringRules) GetUlid() bool { + if x != nil { + if x, ok := x.WellKnown.(*StringRules_Ulid); ok { + return x.Ulid + } + } + return false +} + +func (x *StringRules) GetProtobufFqn() bool { + if x != nil { + if x, ok := x.WellKnown.(*StringRules_ProtobufFqn); ok { + return x.ProtobufFqn + } + } + return false +} + +func (x *StringRules) GetProtobufDotFqn() bool { + if x != nil { + if x, ok := x.WellKnown.(*StringRules_ProtobufDotFqn); ok { + return x.ProtobufDotFqn + } + } + return false +} + +func (x *StringRules) GetWellKnownRegex() KnownRegex { + if x != nil { + if x, ok := x.WellKnown.(*StringRules_WellKnownRegex); ok { + return x.WellKnownRegex + } + } + return KnownRegex_KNOWN_REGEX_UNSPECIFIED +} + +func (x *StringRules) GetStrict() bool { + if x != nil && x.Strict != nil { + return *x.Strict + } + return false +} + +func (x *StringRules) GetExample() []string { + if x != nil { + return x.Example + } + return nil +} + +type isStringRules_WellKnown interface { + isStringRules_WellKnown() +} + +type StringRules_Email struct { + // `email` specifies that the field value must be a valid email address, for + // example "foo@example.com". 
+ // + // Conforms to the definition for a valid email address from the [HTML standard](https://html.spec.whatwg.org/multipage/input.html#valid-e-mail-address). + // Note that this standard willfully deviates from [RFC 5322](https://datatracker.ietf.org/doc/html/rfc5322), + // which allows many unexpected forms of email addresses and will easily match + // a typographical error. + // + // If the field value isn't a valid email address, an error message will be generated. + // + // ```proto + // + // message MyString { + // // must be a valid email address + // string value = 1 [(buf.validate.field).string.email = true]; + // } + // + // ``` + Email bool `protobuf:"varint,12,opt,name=email,oneof"` +} + +type StringRules_Hostname struct { + // `hostname` specifies that the field value must be a valid hostname, for + // example "foo.example.com". + // + // A valid hostname follows the rules below: + // - The name consists of one or more labels, separated by a dot ("."). + // - Each label can be 1 to 63 alphanumeric characters. + // - A label can contain hyphens ("-"), but must not start or end with a hyphen. + // - The right-most label must not be digits only. + // - The name can have a trailing dot—for example, "foo.example.com.". + // - The name can be 253 characters at most, excluding the optional trailing dot. + // + // If the field value isn't a valid hostname, an error message will be generated. + // + // ```proto + // + // message MyString { + // // must be a valid hostname + // string value = 1 [(buf.validate.field).string.hostname = true]; + // } + // + // ``` + Hostname bool `protobuf:"varint,13,opt,name=hostname,oneof"` +} + +type StringRules_Ip struct { + // `ip` specifies that the field value must be a valid IP (v4 or v6) address. + // + // IPv4 addresses are expected in the dotted decimal format—for example, "192.168.5.21". + // IPv6 addresses are expected in their text representation—for example, "::1", + // or "2001:0DB8:ABCD:0012::0". 
+ // + // Both formats are well-defined in the internet standard [RFC 3986](https://datatracker.ietf.org/doc/html/rfc3986). + // Zone identifiers for IPv6 addresses (for example, "fe80::a%en1") are supported. + // + // If the field value isn't a valid IP address, an error message will be + // generated. + // + // ```proto + // + // message MyString { + // // must be a valid IP address + // string value = 1 [(buf.validate.field).string.ip = true]; + // } + // + // ``` + Ip bool `protobuf:"varint,14,opt,name=ip,oneof"` +} + +type StringRules_Ipv4 struct { + // `ipv4` specifies that the field value must be a valid IPv4 address—for + // example "192.168.5.21". If the field value isn't a valid IPv4 address, an + // error message will be generated. + // + // ```proto + // + // message MyString { + // // must be a valid IPv4 address + // string value = 1 [(buf.validate.field).string.ipv4 = true]; + // } + // + // ``` + Ipv4 bool `protobuf:"varint,15,opt,name=ipv4,oneof"` +} + +type StringRules_Ipv6 struct { + // `ipv6` specifies that the field value must be a valid IPv6 address—for + // example "::1", or "d7a:115c:a1e0:ab12:4843:cd96:626b:430b". If the field + // value is not a valid IPv6 address, an error message will be generated. + // + // ```proto + // + // message MyString { + // // must be a valid IPv6 address + // string value = 1 [(buf.validate.field).string.ipv6 = true]; + // } + // + // ``` + Ipv6 bool `protobuf:"varint,16,opt,name=ipv6,oneof"` +} + +type StringRules_Uri struct { + // `uri` specifies that the field value must be a valid URI, for example + // "https://example.com/foo/bar?baz=quux#frag". + // + // URI is defined in the internet standard [RFC 3986](https://datatracker.ietf.org/doc/html/rfc3986). + // Zone Identifiers in IPv6 address literals are supported ([RFC 6874](https://datatracker.ietf.org/doc/html/rfc6874)). + // + // If the field value isn't a valid URI, an error message will be generated. 
+ // + // ```proto + // + // message MyString { + // // must be a valid URI + // string value = 1 [(buf.validate.field).string.uri = true]; + // } + // + // ``` + Uri bool `protobuf:"varint,17,opt,name=uri,oneof"` +} + +type StringRules_UriRef struct { + // `uri_ref` specifies that the field value must be a valid URI Reference—either + // a URI such as "https://example.com/foo/bar?baz=quux#frag", or a Relative + // Reference such as "./foo/bar?query". + // + // URI, URI Reference, and Relative Reference are defined in the internet + // standard [RFC 3986](https://datatracker.ietf.org/doc/html/rfc3986). Zone + // Identifiers in IPv6 address literals are supported ([RFC 6874](https://datatracker.ietf.org/doc/html/rfc6874)). + // + // If the field value isn't a valid URI Reference, an error message will be + // generated. + // + // ```proto + // + // message MyString { + // // must be a valid URI Reference + // string value = 1 [(buf.validate.field).string.uri_ref = true]; + // } + // + // ``` + UriRef bool `protobuf:"varint,18,opt,name=uri_ref,json=uriRef,oneof"` +} + +type StringRules_Address struct { + // `address` specifies that the field value must be either a valid hostname + // (for example, "example.com"), or a valid IP (v4 or v6) address (for example, + // "192.168.0.1", or "::1"). If the field value isn't a valid hostname or IP, + // an error message will be generated. + // + // ```proto + // + // message MyString { + // // must be a valid hostname, or ip address + // string value = 1 [(buf.validate.field).string.address = true]; + // } + // + // ``` + Address bool `protobuf:"varint,21,opt,name=address,oneof"` +} + +type StringRules_Uuid struct { + // `uuid` specifies that the field value must be a valid UUID as defined by + // [RFC 4122](https://datatracker.ietf.org/doc/html/rfc4122#section-4.1.2). If the + // field value isn't a valid UUID, an error message will be generated. 
+ // + // ```proto + // + // message MyString { + // // must be a valid UUID + // string value = 1 [(buf.validate.field).string.uuid = true]; + // } + // + // ``` + Uuid bool `protobuf:"varint,22,opt,name=uuid,oneof"` +} + +type StringRules_Tuuid struct { + // `tuuid` (trimmed UUID) specifies that the field value must be a valid UUID as + // defined by [RFC 4122](https://datatracker.ietf.org/doc/html/rfc4122#section-4.1.2) with all dashes + // omitted. If the field value isn't a valid UUID without dashes, an error message + // will be generated. + // + // ```proto + // + // message MyString { + // // must be a valid trimmed UUID + // string value = 1 [(buf.validate.field).string.tuuid = true]; + // } + // + // ``` + Tuuid bool `protobuf:"varint,33,opt,name=tuuid,oneof"` +} + +type StringRules_IpWithPrefixlen struct { + // `ip_with_prefixlen` specifies that the field value must be a valid IP + // (v4 or v6) address with prefix length—for example, "192.168.5.21/16" or + // "2001:0DB8:ABCD:0012::F1/64". If the field value isn't a valid IP with + // prefix length, an error message will be generated. + // + // ```proto + // + // message MyString { + // // must be a valid IP with prefix length + // string value = 1 [(buf.validate.field).string.ip_with_prefixlen = true]; + // } + // + // ``` + IpWithPrefixlen bool `protobuf:"varint,26,opt,name=ip_with_prefixlen,json=ipWithPrefixlen,oneof"` +} + +type StringRules_Ipv4WithPrefixlen struct { + // `ipv4_with_prefixlen` specifies that the field value must be a valid + // IPv4 address with prefix length—for example, "192.168.5.21/16". If the + // field value isn't a valid IPv4 address with prefix length, an error + // message will be generated. 
+ // + // ```proto + // + // message MyString { + // // must be a valid IPv4 address with prefix length + // string value = 1 [(buf.validate.field).string.ipv4_with_prefixlen = true]; + // } + // + // ``` + Ipv4WithPrefixlen bool `protobuf:"varint,27,opt,name=ipv4_with_prefixlen,json=ipv4WithPrefixlen,oneof"` +} + +type StringRules_Ipv6WithPrefixlen struct { + // `ipv6_with_prefixlen` specifies that the field value must be a valid + // IPv6 address with prefix length—for example, "2001:0DB8:ABCD:0012::F1/64". + // If the field value is not a valid IPv6 address with prefix length, + // an error message will be generated. + // + // ```proto + // + // message MyString { + // // must be a valid IPv6 address prefix length + // string value = 1 [(buf.validate.field).string.ipv6_with_prefixlen = true]; + // } + // + // ``` + Ipv6WithPrefixlen bool `protobuf:"varint,28,opt,name=ipv6_with_prefixlen,json=ipv6WithPrefixlen,oneof"` +} + +type StringRules_IpPrefix struct { + // `ip_prefix` specifies that the field value must be a valid IP (v4 or v6) + // prefix—for example, "192.168.0.0/16" or "2001:0DB8:ABCD:0012::0/64". + // + // The prefix must have all zeros for the unmasked bits. For example, + // "2001:0DB8:ABCD:0012::0/64" designates the left-most 64 bits for the + // prefix, and the remaining 64 bits must be zero. + // + // If the field value isn't a valid IP prefix, an error message will be + // generated. + // + // ```proto + // + // message MyString { + // // must be a valid IP prefix + // string value = 1 [(buf.validate.field).string.ip_prefix = true]; + // } + // + // ``` + IpPrefix bool `protobuf:"varint,29,opt,name=ip_prefix,json=ipPrefix,oneof"` +} + +type StringRules_Ipv4Prefix struct { + // `ipv4_prefix` specifies that the field value must be a valid IPv4 + // prefix, for example "192.168.0.0/16". + // + // The prefix must have all zeros for the unmasked bits. 
For example, + // "192.168.0.0/16" designates the left-most 16 bits for the prefix, + // and the remaining 16 bits must be zero. + // + // If the field value isn't a valid IPv4 prefix, an error message + // will be generated. + // + // ```proto + // + // message MyString { + // // must be a valid IPv4 prefix + // string value = 1 [(buf.validate.field).string.ipv4_prefix = true]; + // } + // + // ``` + Ipv4Prefix bool `protobuf:"varint,30,opt,name=ipv4_prefix,json=ipv4Prefix,oneof"` +} + +type StringRules_Ipv6Prefix struct { + // `ipv6_prefix` specifies that the field value must be a valid IPv6 prefix—for + // example, "2001:0DB8:ABCD:0012::0/64". + // + // The prefix must have all zeros for the unmasked bits. For example, + // "2001:0DB8:ABCD:0012::0/64" designates the left-most 64 bits for the + // prefix, and the remaining 64 bits must be zero. + // + // If the field value is not a valid IPv6 prefix, an error message will be + // generated. + // + // ```proto + // + // message MyString { + // // must be a valid IPv6 prefix + // string value = 1 [(buf.validate.field).string.ipv6_prefix = true]; + // } + // + // ``` + Ipv6Prefix bool `protobuf:"varint,31,opt,name=ipv6_prefix,json=ipv6Prefix,oneof"` +} + +type StringRules_HostAndPort struct { + // `host_and_port` specifies that the field value must be a valid host/port + // pair—for example, "example.com:8080". + // + // The host can be one of: + // - An IPv4 address in dotted decimal format—for example, "192.168.5.21". + // - An IPv6 address enclosed in square brackets—for example, "[2001:0DB8:ABCD:0012::F1]". + // - A hostname—for example, "example.com". + // + // The port is separated by a colon. It must be non-empty, with a decimal number + // in the range of 0-65535, inclusive. 
+ HostAndPort bool `protobuf:"varint,32,opt,name=host_and_port,json=hostAndPort,oneof"` +} + +type StringRules_Ulid struct { + // `ulid` specifies that the field value must be a valid ULID (Universally Unique + // Lexicographically Sortable Identifier) as defined by the [ULID specification](https://github.com/ulid/spec). + // If the field value isn't a valid ULID, an error message will be generated. + // + // ```proto + // + // message MyString { + // // must be a valid ULID + // string value = 1 [(buf.validate.field).string.ulid = true]; + // } + // + // ``` + Ulid bool `protobuf:"varint,35,opt,name=ulid,oneof"` +} + +type StringRules_ProtobufFqn struct { + // `protobuf_fqn` specifies that the field value must be a valid fully-qualified + // Protobuf name as defined by the [Protobuf Language Specification](https://protobuf.com/docs/language-spec). + // + // A fully-qualified Protobuf name is a dot-separated list of Protobuf identifiers, + // where each identifier starts with a letter or underscore and is followed by zero or + // more letters, underscores, or digits. + // + // Examples: "buf.validate", "google.protobuf.Timestamp", "my_package.MyMessage". + // + // Note: historically, fully-qualified Protobuf names were represented with a leading + // dot (for example, ".buf.validate.StringRules"). Modern Protobuf does not use the + // leading dot, and most fully-qualified names are represented without it. Use + // `protobuf_dot_fqn` if a leading dot is required. + // + // If the field value isn't a valid fully-qualified Protobuf name, an error message + // will be generated. 
+ // + // ```proto + // + // message MyString { + // // value must be a valid fully-qualified Protobuf name + // string value = 1 [(buf.validate.field).string.protobuf_fqn = true]; + // } + // + // ``` + ProtobufFqn bool `protobuf:"varint,37,opt,name=protobuf_fqn,json=protobufFqn,oneof"` +} + +type StringRules_ProtobufDotFqn struct { + // `protobuf_dot_fqn` specifies that the field value must be a valid fully-qualified + // Protobuf name with a leading dot, as defined by the + // [Protobuf Language Specification](https://protobuf.com/docs/language-spec). + // + // A fully-qualified Protobuf name with a leading dot is a dot followed by a + // dot-separated list of Protobuf identifiers, where each identifier starts with a + // letter or underscore and is followed by zero or more letters, underscores, or + // digits. + // + // Examples: ".buf.validate", ".google.protobuf.Timestamp", ".my_package.MyMessage". + // + // Note: this is the historical representation of fully-qualified Protobuf names, + // where a leading dot denotes an absolute reference. Modern Protobuf does not use + // the leading dot, and most fully-qualified names are represented without it. Most + // users will want to use `protobuf_fqn` instead. + // + // If the field value isn't a valid fully-qualified Protobuf name with a leading dot, + // an error message will be generated. + // + // ```proto + // + // message MyString { + // // value must be a valid fully-qualified Protobuf name with a leading dot + // string value = 1 [(buf.validate.field).string.protobuf_dot_fqn = true]; + // } + // + // ``` + ProtobufDotFqn bool `protobuf:"varint,38,opt,name=protobuf_dot_fqn,json=protobufDotFqn,oneof"` +} + +type StringRules_WellKnownRegex struct { + // `well_known_regex` specifies a common well-known pattern + // defined as a regex. If the field value doesn't match the well-known + // regex, an error message will be generated. 
+ // + // ```proto + // + // message MyString { + // // must be a valid HTTP header value + // string value = 1 [(buf.validate.field).string.well_known_regex = KNOWN_REGEX_HTTP_HEADER_VALUE]; + // } + // + // ``` + // + // #### KnownRegex + // + // `well_known_regex` contains some well-known patterns. + // + // | Name | Number | Description | + // |-------------------------------|--------|-------------------------------------------| + // | KNOWN_REGEX_UNSPECIFIED | 0 | | + // | KNOWN_REGEX_HTTP_HEADER_NAME | 1 | HTTP header name as defined by [RFC 7230](https://datatracker.ietf.org/doc/html/rfc7230#section-3.2) | + // | KNOWN_REGEX_HTTP_HEADER_VALUE | 2 | HTTP header value as defined by [RFC 7230](https://datatracker.ietf.org/doc/html/rfc7230#section-3.2.4) | + WellKnownRegex KnownRegex `protobuf:"varint,24,opt,name=well_known_regex,json=wellKnownRegex,enum=buf.validate.KnownRegex,oneof"` +} + +func (*StringRules_Email) isStringRules_WellKnown() {} + +func (*StringRules_Hostname) isStringRules_WellKnown() {} + +func (*StringRules_Ip) isStringRules_WellKnown() {} + +func (*StringRules_Ipv4) isStringRules_WellKnown() {} + +func (*StringRules_Ipv6) isStringRules_WellKnown() {} + +func (*StringRules_Uri) isStringRules_WellKnown() {} + +func (*StringRules_UriRef) isStringRules_WellKnown() {} + +func (*StringRules_Address) isStringRules_WellKnown() {} + +func (*StringRules_Uuid) isStringRules_WellKnown() {} + +func (*StringRules_Tuuid) isStringRules_WellKnown() {} + +func (*StringRules_IpWithPrefixlen) isStringRules_WellKnown() {} + +func (*StringRules_Ipv4WithPrefixlen) isStringRules_WellKnown() {} + +func (*StringRules_Ipv6WithPrefixlen) isStringRules_WellKnown() {} + +func (*StringRules_IpPrefix) isStringRules_WellKnown() {} + +func (*StringRules_Ipv4Prefix) isStringRules_WellKnown() {} + +func (*StringRules_Ipv6Prefix) isStringRules_WellKnown() {} + +func (*StringRules_HostAndPort) isStringRules_WellKnown() {} + +func (*StringRules_Ulid) isStringRules_WellKnown() {} 
+ +func (*StringRules_ProtobufFqn) isStringRules_WellKnown() {} + +func (*StringRules_ProtobufDotFqn) isStringRules_WellKnown() {} + +func (*StringRules_WellKnownRegex) isStringRules_WellKnown() {} + +// BytesRules describe the rules applied to `bytes` values. These rules +// may also be applied to the `google.protobuf.BytesValue` Well-Known-Type. +type BytesRules struct { + state protoimpl.MessageState `protogen:"open.v1"` + // `const` requires the field value to exactly match the specified bytes + // value. If the field value doesn't match, an error message is generated. + // + // ```proto + // + // message MyBytes { + // // must be "\x01\x02\x03\x04" + // bytes value = 1 [(buf.validate.field).bytes.const = "\x01\x02\x03\x04"]; + // } + // + // ``` + Const []byte `protobuf:"bytes,1,opt,name=const" json:"const,omitempty"` + // `len` requires the field value to have the specified length in bytes. + // If the field value doesn't match, an error message is generated. + // + // ```proto + // + // message MyBytes { + // // value length must be 4 bytes. + // optional bytes value = 1 [(buf.validate.field).bytes.len = 4]; + // } + // + // ``` + Len *uint64 `protobuf:"varint,13,opt,name=len" json:"len,omitempty"` + // `min_len` requires the field value to have at least the specified minimum + // length in bytes. + // If the field value doesn't meet the requirement, an error message is generated. + // + // ```proto + // + // message MyBytes { + // // value length must be at least 2 bytes. + // optional bytes value = 1 [(buf.validate.field).bytes.min_len = 2]; + // } + // + // ``` + MinLen *uint64 `protobuf:"varint,2,opt,name=min_len,json=minLen" json:"min_len,omitempty"` + // `max_len` requires the field value to have at most the specified maximum + // length in bytes. + // If the field value exceeds the requirement, an error message is generated. + // + // ```proto + // + // message MyBytes { + // // must be at most 6 bytes. 
+ // optional bytes value = 1 [(buf.validate.field).bytes.max_len = 6]; + // } + // + // ``` + MaxLen *uint64 `protobuf:"varint,3,opt,name=max_len,json=maxLen" json:"max_len,omitempty"` + // `pattern` requires the field value to match the specified regular + // expression ([RE2 syntax](https://github.com/google/re2/wiki/Syntax)). + // The value of the field must be valid UTF-8 or validation will fail with a + // runtime error. + // If the field value doesn't match the pattern, an error message is generated. + // + // ```proto + // + // message MyBytes { + // // value must match regex pattern "^[a-zA-Z0-9]+$". + // optional bytes value = 1 [(buf.validate.field).bytes.pattern = "^[a-zA-Z0-9]+$"]; + // } + // + // ``` + Pattern *string `protobuf:"bytes,4,opt,name=pattern" json:"pattern,omitempty"` + // `prefix` requires the field value to have the specified bytes at the + // beginning of the string. + // If the field value doesn't meet the requirement, an error message is generated. + // + // ```proto + // + // message MyBytes { + // // value does not have prefix \x01\x02 + // optional bytes value = 1 [(buf.validate.field).bytes.prefix = "\x01\x02"]; + // } + // + // ``` + Prefix []byte `protobuf:"bytes,5,opt,name=prefix" json:"prefix,omitempty"` + // `suffix` requires the field value to have the specified bytes at the end + // of the string. + // If the field value doesn't meet the requirement, an error message is generated. + // + // ```proto + // + // message MyBytes { + // // value does not have suffix \x03\x04 + // optional bytes value = 1 [(buf.validate.field).bytes.suffix = "\x03\x04"]; + // } + // + // ``` + Suffix []byte `protobuf:"bytes,6,opt,name=suffix" json:"suffix,omitempty"` + // `contains` requires the field value to have the specified bytes anywhere in + // the string. + // If the field value doesn't meet the requirement, an error message is generated. 
+ // + // ```proto + // + // message MyBytes { + // // value does not contain \x02\x03 + // optional bytes value = 1 [(buf.validate.field).bytes.contains = "\x02\x03"]; + // } + // + // ``` + Contains []byte `protobuf:"bytes,7,opt,name=contains" json:"contains,omitempty"` + // `in` requires the field value to be equal to one of the specified + // values. If the field value doesn't match any of the specified values, an + // error message is generated. + // + // ```proto + // + // message MyBytes { + // // value must in ["\x01\x02", "\x02\x03", "\x03\x04"] + // optional bytes value = 1 [(buf.validate.field).bytes.in = {"\x01\x02", "\x02\x03", "\x03\x04"}]; + // } + // + // ``` + In [][]byte `protobuf:"bytes,8,rep,name=in" json:"in,omitempty"` + // `not_in` requires the field value to be not equal to any of the specified + // values. + // If the field value matches any of the specified values, an error message is + // generated. + // + // ```proto + // + // message MyBytes { + // // value must not in ["\x01\x02", "\x02\x03", "\x03\x04"] + // optional bytes value = 1 [(buf.validate.field).bytes.not_in = {"\x01\x02", "\x02\x03", "\x03\x04"}]; + // } + // + // ``` + NotIn [][]byte `protobuf:"bytes,9,rep,name=not_in,json=notIn" json:"not_in,omitempty"` + // WellKnown rules provide advanced rules against common byte + // patterns + // + // Types that are valid to be assigned to WellKnown: + // + // *BytesRules_Ip + // *BytesRules_Ipv4 + // *BytesRules_Ipv6 + // *BytesRules_Uuid + WellKnown isBytesRules_WellKnown `protobuf_oneof:"well_known"` + // `example` specifies values that the field may have. These values SHOULD + // conform to other rules. `example` values will not impact validation + // but may be used as helpful guidance on how to populate the given field. 
+ // + // ```proto + // + // message MyBytes { + // bytes value = 1 [ + // (buf.validate.field).bytes.example = "\x01\x02", + // (buf.validate.field).bytes.example = "\x02\x03" + // ]; + // } + // + // ``` + Example [][]byte `protobuf:"bytes,14,rep,name=example" json:"example,omitempty"` + extensionFields protoimpl.ExtensionFields + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *BytesRules) Reset() { + *x = BytesRules{} + mi := &file_buf_validate_validate_proto_msgTypes[20] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *BytesRules) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*BytesRules) ProtoMessage() {} + +func (x *BytesRules) ProtoReflect() protoreflect.Message { + mi := &file_buf_validate_validate_proto_msgTypes[20] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use BytesRules.ProtoReflect.Descriptor instead. 
+func (*BytesRules) Descriptor() ([]byte, []int) { + return file_buf_validate_validate_proto_rawDescGZIP(), []int{20} +} + +func (x *BytesRules) GetConst() []byte { + if x != nil { + return x.Const + } + return nil +} + +func (x *BytesRules) GetLen() uint64 { + if x != nil && x.Len != nil { + return *x.Len + } + return 0 +} + +func (x *BytesRules) GetMinLen() uint64 { + if x != nil && x.MinLen != nil { + return *x.MinLen + } + return 0 +} + +func (x *BytesRules) GetMaxLen() uint64 { + if x != nil && x.MaxLen != nil { + return *x.MaxLen + } + return 0 +} + +func (x *BytesRules) GetPattern() string { + if x != nil && x.Pattern != nil { + return *x.Pattern + } + return "" +} + +func (x *BytesRules) GetPrefix() []byte { + if x != nil { + return x.Prefix + } + return nil +} + +func (x *BytesRules) GetSuffix() []byte { + if x != nil { + return x.Suffix + } + return nil +} + +func (x *BytesRules) GetContains() []byte { + if x != nil { + return x.Contains + } + return nil +} + +func (x *BytesRules) GetIn() [][]byte { + if x != nil { + return x.In + } + return nil +} + +func (x *BytesRules) GetNotIn() [][]byte { + if x != nil { + return x.NotIn + } + return nil +} + +func (x *BytesRules) GetWellKnown() isBytesRules_WellKnown { + if x != nil { + return x.WellKnown + } + return nil +} + +func (x *BytesRules) GetIp() bool { + if x != nil { + if x, ok := x.WellKnown.(*BytesRules_Ip); ok { + return x.Ip + } + } + return false +} + +func (x *BytesRules) GetIpv4() bool { + if x != nil { + if x, ok := x.WellKnown.(*BytesRules_Ipv4); ok { + return x.Ipv4 + } + } + return false +} + +func (x *BytesRules) GetIpv6() bool { + if x != nil { + if x, ok := x.WellKnown.(*BytesRules_Ipv6); ok { + return x.Ipv6 + } + } + return false +} + +func (x *BytesRules) GetUuid() bool { + if x != nil { + if x, ok := x.WellKnown.(*BytesRules_Uuid); ok { + return x.Uuid + } + } + return false +} + +func (x *BytesRules) GetExample() [][]byte { + if x != nil { + return x.Example + } + return nil +} + +type 
isBytesRules_WellKnown interface { + isBytesRules_WellKnown() +} + +type BytesRules_Ip struct { + // `ip` ensures that the field `value` is a valid IP address (v4 or v6) in byte format. + // If the field value doesn't meet this rule, an error message is generated. + // + // ```proto + // + // message MyBytes { + // // must be a valid IP address + // optional bytes value = 1 [(buf.validate.field).bytes.ip = true]; + // } + // + // ``` + Ip bool `protobuf:"varint,10,opt,name=ip,oneof"` +} + +type BytesRules_Ipv4 struct { + // `ipv4` ensures that the field `value` is a valid IPv4 address in byte format. + // If the field value doesn't meet this rule, an error message is generated. + // + // ```proto + // + // message MyBytes { + // // must be a valid IPv4 address + // optional bytes value = 1 [(buf.validate.field).bytes.ipv4 = true]; + // } + // + // ``` + Ipv4 bool `protobuf:"varint,11,opt,name=ipv4,oneof"` +} + +type BytesRules_Ipv6 struct { + // `ipv6` ensures that the field `value` is a valid IPv6 address in byte format. + // If the field value doesn't meet this rule, an error message is generated. + // ```proto + // + // message MyBytes { + // // must be a valid IPv6 address + // optional bytes value = 1 [(buf.validate.field).bytes.ipv6 = true]; + // } + // + // ``` + Ipv6 bool `protobuf:"varint,12,opt,name=ipv6,oneof"` +} + +type BytesRules_Uuid struct { + // `uuid` ensures that the field value encodes 128-bit UUID data as defined + // by [RFC 4122](https://datatracker.ietf.org/doc/html/rfc4122#section-4.1.2). + // The field must contain exactly 16 bytes representing the UUID. If the + // field value isn't a valid UUID, an error message will be generated. 
+ // + // ```proto + // + // message MyBytes { + // // must be a valid UUID + // optional bytes value = 1 [(buf.validate.field).bytes.uuid = true]; + // } + // + // ``` + Uuid bool `protobuf:"varint,15,opt,name=uuid,oneof"` +} + +func (*BytesRules_Ip) isBytesRules_WellKnown() {} + +func (*BytesRules_Ipv4) isBytesRules_WellKnown() {} + +func (*BytesRules_Ipv6) isBytesRules_WellKnown() {} + +func (*BytesRules_Uuid) isBytesRules_WellKnown() {} + +// EnumRules describe the rules applied to `enum` values. +type EnumRules struct { + state protoimpl.MessageState `protogen:"open.v1"` + // `const` requires the field value to exactly match the specified enum value. + // If the field value doesn't match, an error message is generated. + // + // ```proto + // + // enum MyEnum { + // MY_ENUM_UNSPECIFIED = 0; + // MY_ENUM_VALUE1 = 1; + // MY_ENUM_VALUE2 = 2; + // } + // + // message MyMessage { + // // The field `value` must be exactly MY_ENUM_VALUE1. + // MyEnum value = 1 [(buf.validate.field).enum.const = 1]; + // } + // + // ``` + Const *int32 `protobuf:"varint,1,opt,name=const" json:"const,omitempty"` + // `defined_only` requires the field value to be one of the defined values for + // this enum, failing on any undefined value. + // + // ```proto + // + // enum MyEnum { + // MY_ENUM_UNSPECIFIED = 0; + // MY_ENUM_VALUE1 = 1; + // MY_ENUM_VALUE2 = 2; + // } + // + // message MyMessage { + // // The field `value` must be a defined value of MyEnum. + // MyEnum value = 1 [(buf.validate.field).enum.defined_only = true]; + // } + // + // ``` + DefinedOnly *bool `protobuf:"varint,2,opt,name=defined_only,json=definedOnly" json:"defined_only,omitempty"` + // `in` requires the field value to be equal to one of the + // specified enum values. If the field value doesn't match any of the + // specified values, an error message is generated. 
+ // + // ```proto + // + // enum MyEnum { + // MY_ENUM_UNSPECIFIED = 0; + // MY_ENUM_VALUE1 = 1; + // MY_ENUM_VALUE2 = 2; + // } + // + // message MyMessage { + // // The field `value` must be equal to one of the specified values. + // MyEnum value = 1 [(buf.validate.field).enum = { in: [1, 2]}]; + // } + // + // ``` + In []int32 `protobuf:"varint,3,rep,name=in" json:"in,omitempty"` + // `not_in` requires the field value to be not equal to any of the + // specified enum values. If the field value matches one of the specified + // values, an error message is generated. + // + // ```proto + // + // enum MyEnum { + // MY_ENUM_UNSPECIFIED = 0; + // MY_ENUM_VALUE1 = 1; + // MY_ENUM_VALUE2 = 2; + // } + // + // message MyMessage { + // // The field `value` must not be equal to any of the specified values. + // MyEnum value = 1 [(buf.validate.field).enum = { not_in: [1, 2]}]; + // } + // + // ``` + NotIn []int32 `protobuf:"varint,4,rep,name=not_in,json=notIn" json:"not_in,omitempty"` + // `example` specifies values that the field may have. These values SHOULD + // conform to other rules. `example` values will not impact validation + // but may be used as helpful guidance on how to populate the given field. 
+ // + // ```proto + // + // enum MyEnum { + // MY_ENUM_UNSPECIFIED = 0; + // MY_ENUM_VALUE1 = 1; + // MY_ENUM_VALUE2 = 2; + // } + // + // message MyMessage { + // (buf.validate.field).enum.example = 1, + // (buf.validate.field).enum.example = 2 + // } + // + // ``` + Example []int32 `protobuf:"varint,5,rep,name=example" json:"example,omitempty"` + extensionFields protoimpl.ExtensionFields + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *EnumRules) Reset() { + *x = EnumRules{} + mi := &file_buf_validate_validate_proto_msgTypes[21] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *EnumRules) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*EnumRules) ProtoMessage() {} + +func (x *EnumRules) ProtoReflect() protoreflect.Message { + mi := &file_buf_validate_validate_proto_msgTypes[21] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use EnumRules.ProtoReflect.Descriptor instead. +func (*EnumRules) Descriptor() ([]byte, []int) { + return file_buf_validate_validate_proto_rawDescGZIP(), []int{21} +} + +func (x *EnumRules) GetConst() int32 { + if x != nil && x.Const != nil { + return *x.Const + } + return 0 +} + +func (x *EnumRules) GetDefinedOnly() bool { + if x != nil && x.DefinedOnly != nil { + return *x.DefinedOnly + } + return false +} + +func (x *EnumRules) GetIn() []int32 { + if x != nil { + return x.In + } + return nil +} + +func (x *EnumRules) GetNotIn() []int32 { + if x != nil { + return x.NotIn + } + return nil +} + +func (x *EnumRules) GetExample() []int32 { + if x != nil { + return x.Example + } + return nil +} + +// RepeatedRules describe the rules applied to `repeated` values. 
+type RepeatedRules struct { + state protoimpl.MessageState `protogen:"open.v1"` + // `min_items` requires that this field must contain at least the specified + // minimum number of items. + // + // Note that `min_items = 1` is equivalent to setting a field as `required`. + // + // ```proto + // + // message MyRepeated { + // // value must contain at least 2 items + // repeated string value = 1 [(buf.validate.field).repeated.min_items = 2]; + // } + // + // ``` + MinItems *uint64 `protobuf:"varint,1,opt,name=min_items,json=minItems" json:"min_items,omitempty"` + // `max_items` denotes that this field must not exceed a + // certain number of items as the upper limit. If the field contains more + // items than specified, an error message will be generated, requiring the + // field to maintain no more than the specified number of items. + // + // ```proto + // + // message MyRepeated { + // // value must contain no more than 3 item(s) + // repeated string value = 1 [(buf.validate.field).repeated.max_items = 3]; + // } + // + // ``` + MaxItems *uint64 `protobuf:"varint,2,opt,name=max_items,json=maxItems" json:"max_items,omitempty"` + // `unique` indicates that all elements in this field must + // be unique. This rule is strictly applicable to scalar and enum + // types, with message types not being supported. + // + // ```proto + // + // message MyRepeated { + // // repeated value must contain unique items + // repeated string value = 1 [(buf.validate.field).repeated.unique = true]; + // } + // + // ``` + Unique *bool `protobuf:"varint,3,opt,name=unique" json:"unique,omitempty"` + // `items` details the rules to be applied to each item + // in the field. Even for repeated message fields, validation is executed + // against each item unless `ignore` is specified. + // + // ```proto + // + // message MyRepeated { + // // The items in the field `value` must follow the specified rules. 
+ // repeated string value = 1 [(buf.validate.field).repeated.items = { + // string: { + // min_len: 3 + // max_len: 10 + // } + // }]; + // } + // + // ``` + // + // Note that the `required` rule does not apply. Repeated items + // cannot be unset. + Items *FieldRules `protobuf:"bytes,4,opt,name=items" json:"items,omitempty"` + extensionFields protoimpl.ExtensionFields + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *RepeatedRules) Reset() { + *x = RepeatedRules{} + mi := &file_buf_validate_validate_proto_msgTypes[22] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *RepeatedRules) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RepeatedRules) ProtoMessage() {} + +func (x *RepeatedRules) ProtoReflect() protoreflect.Message { + mi := &file_buf_validate_validate_proto_msgTypes[22] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RepeatedRules.ProtoReflect.Descriptor instead. +func (*RepeatedRules) Descriptor() ([]byte, []int) { + return file_buf_validate_validate_proto_rawDescGZIP(), []int{22} +} + +func (x *RepeatedRules) GetMinItems() uint64 { + if x != nil && x.MinItems != nil { + return *x.MinItems + } + return 0 +} + +func (x *RepeatedRules) GetMaxItems() uint64 { + if x != nil && x.MaxItems != nil { + return *x.MaxItems + } + return 0 +} + +func (x *RepeatedRules) GetUnique() bool { + if x != nil && x.Unique != nil { + return *x.Unique + } + return false +} + +func (x *RepeatedRules) GetItems() *FieldRules { + if x != nil { + return x.Items + } + return nil +} + +// MapRules describe the rules applied to `map` values. +type MapRules struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Specifies the minimum number of key-value pairs allowed. 
If the field has + // fewer key-value pairs than specified, an error message is generated. + // + // ```proto + // + // message MyMap { + // // The field `value` must have at least 2 key-value pairs. + // map value = 1 [(buf.validate.field).map.min_pairs = 2]; + // } + // + // ``` + MinPairs *uint64 `protobuf:"varint,1,opt,name=min_pairs,json=minPairs" json:"min_pairs,omitempty"` + // Specifies the maximum number of key-value pairs allowed. If the field has + // more key-value pairs than specified, an error message is generated. + // + // ```proto + // + // message MyMap { + // // The field `value` must have at most 3 key-value pairs. + // map value = 1 [(buf.validate.field).map.max_pairs = 3]; + // } + // + // ``` + MaxPairs *uint64 `protobuf:"varint,2,opt,name=max_pairs,json=maxPairs" json:"max_pairs,omitempty"` + // Specifies the rules to be applied to each key in the field. + // + // ```proto + // + // message MyMap { + // // The keys in the field `value` must follow the specified rules. + // map value = 1 [(buf.validate.field).map.keys = { + // string: { + // min_len: 3 + // max_len: 10 + // } + // }]; + // } + // + // ``` + // + // Note that the `required` rule does not apply. Map keys cannot be unset. + Keys *FieldRules `protobuf:"bytes,4,opt,name=keys" json:"keys,omitempty"` + // Specifies the rules to be applied to the value of each key in the + // field. Message values will still have their validations evaluated unless + // `ignore` is specified. + // + // ```proto + // + // message MyMap { + // // The values in the field `value` must follow the specified rules. + // map value = 1 [(buf.validate.field).map.values = { + // string: { + // min_len: 5 + // max_len: 20 + // } + // }]; + // } + // + // ``` + // Note that the `required` rule does not apply. Map values cannot be unset. 
+ Values *FieldRules `protobuf:"bytes,5,opt,name=values" json:"values,omitempty"` + extensionFields protoimpl.ExtensionFields + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *MapRules) Reset() { + *x = MapRules{} + mi := &file_buf_validate_validate_proto_msgTypes[23] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *MapRules) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*MapRules) ProtoMessage() {} + +func (x *MapRules) ProtoReflect() protoreflect.Message { + mi := &file_buf_validate_validate_proto_msgTypes[23] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use MapRules.ProtoReflect.Descriptor instead. +func (*MapRules) Descriptor() ([]byte, []int) { + return file_buf_validate_validate_proto_rawDescGZIP(), []int{23} +} + +func (x *MapRules) GetMinPairs() uint64 { + if x != nil && x.MinPairs != nil { + return *x.MinPairs + } + return 0 +} + +func (x *MapRules) GetMaxPairs() uint64 { + if x != nil && x.MaxPairs != nil { + return *x.MaxPairs + } + return 0 +} + +func (x *MapRules) GetKeys() *FieldRules { + if x != nil { + return x.Keys + } + return nil +} + +func (x *MapRules) GetValues() *FieldRules { + if x != nil { + return x.Values + } + return nil +} + +// AnyRules describe rules applied exclusively to the `google.protobuf.Any` well-known type. +type AnyRules struct { + state protoimpl.MessageState `protogen:"open.v1"` + // `in` requires the field's `type_url` to be equal to one of the + // specified values. If it doesn't match any of the specified values, an error + // message is generated. + // + // ```proto + // + // message MyAny { + // // The `value` field must have a `type_url` equal to one of the specified values. 
+ // google.protobuf.Any value = 1 [(buf.validate.field).any = { + // in: ["type.googleapis.com/MyType1", "type.googleapis.com/MyType2"] + // }]; + // } + // + // ``` + In []string `protobuf:"bytes,2,rep,name=in" json:"in,omitempty"` + // `not_in` requires the field's type_url to be not equal to any of the specified values. If it matches any of the specified values, an error message is generated. + // + // ```proto + // + // message MyAny { + // // The `value` field must not have a `type_url` equal to any of the specified values. + // google.protobuf.Any value = 1 [(buf.validate.field).any = { + // not_in: ["type.googleapis.com/ForbiddenType1", "type.googleapis.com/ForbiddenType2"] + // }]; + // } + // + // ``` + NotIn []string `protobuf:"bytes,3,rep,name=not_in,json=notIn" json:"not_in,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *AnyRules) Reset() { + *x = AnyRules{} + mi := &file_buf_validate_validate_proto_msgTypes[24] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *AnyRules) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AnyRules) ProtoMessage() {} + +func (x *AnyRules) ProtoReflect() protoreflect.Message { + mi := &file_buf_validate_validate_proto_msgTypes[24] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use AnyRules.ProtoReflect.Descriptor instead. +func (*AnyRules) Descriptor() ([]byte, []int) { + return file_buf_validate_validate_proto_rawDescGZIP(), []int{24} +} + +func (x *AnyRules) GetIn() []string { + if x != nil { + return x.In + } + return nil +} + +func (x *AnyRules) GetNotIn() []string { + if x != nil { + return x.NotIn + } + return nil +} + +// DurationRules describe the rules applied exclusively to the `google.protobuf.Duration` well-known type. 
+type DurationRules struct { + state protoimpl.MessageState `protogen:"open.v1"` + // `const` dictates that the field must match the specified value of the `google.protobuf.Duration` type exactly. + // If the field's value deviates from the specified value, an error message + // will be generated. + // + // ```proto + // + // message MyDuration { + // // value must equal 5s + // google.protobuf.Duration value = 1 [(buf.validate.field).duration.const = "5s"]; + // } + // + // ``` + Const *durationpb.Duration `protobuf:"bytes,2,opt,name=const" json:"const,omitempty"` + // Types that are valid to be assigned to LessThan: + // + // *DurationRules_Lt + // *DurationRules_Lte + LessThan isDurationRules_LessThan `protobuf_oneof:"less_than"` + // Types that are valid to be assigned to GreaterThan: + // + // *DurationRules_Gt + // *DurationRules_Gte + GreaterThan isDurationRules_GreaterThan `protobuf_oneof:"greater_than"` + // `in` asserts that the field must be equal to one of the specified values of the `google.protobuf.Duration` type. + // If the field's value doesn't correspond to any of the specified values, + // an error message will be generated. + // + // ```proto + // + // message MyDuration { + // // must be in list [1s, 2s, 3s] + // google.protobuf.Duration value = 1 [(buf.validate.field).duration.in = ["1s", "2s", "3s"]]; + // } + // + // ``` + In []*durationpb.Duration `protobuf:"bytes,7,rep,name=in" json:"in,omitempty"` + // `not_in` denotes that the field must not be equal to + // any of the specified values of the `google.protobuf.Duration` type. + // If the field's value matches any of these values, an error message will be + // generated. 
+ // + // ```proto + // + // message MyDuration { + // // value must not be in list [1s, 2s, 3s] + // google.protobuf.Duration value = 1 [(buf.validate.field).duration.not_in = ["1s", "2s", "3s"]]; + // } + // + // ``` + NotIn []*durationpb.Duration `protobuf:"bytes,8,rep,name=not_in,json=notIn" json:"not_in,omitempty"` + // `example` specifies values that the field may have. These values SHOULD + // conform to other rules. `example` values will not impact validation + // but may be used as helpful guidance on how to populate the given field. + // + // ```proto + // + // message MyDuration { + // google.protobuf.Duration value = 1 [ + // (buf.validate.field).duration.example = { seconds: 1 }, + // (buf.validate.field).duration.example = { seconds: 2 }, + // ]; + // } + // + // ``` + Example []*durationpb.Duration `protobuf:"bytes,9,rep,name=example" json:"example,omitempty"` + extensionFields protoimpl.ExtensionFields + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *DurationRules) Reset() { + *x = DurationRules{} + mi := &file_buf_validate_validate_proto_msgTypes[25] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *DurationRules) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DurationRules) ProtoMessage() {} + +func (x *DurationRules) ProtoReflect() protoreflect.Message { + mi := &file_buf_validate_validate_proto_msgTypes[25] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DurationRules.ProtoReflect.Descriptor instead. 
+func (*DurationRules) Descriptor() ([]byte, []int) { + return file_buf_validate_validate_proto_rawDescGZIP(), []int{25} +} + +func (x *DurationRules) GetConst() *durationpb.Duration { + if x != nil { + return x.Const + } + return nil +} + +func (x *DurationRules) GetLessThan() isDurationRules_LessThan { + if x != nil { + return x.LessThan + } + return nil +} + +func (x *DurationRules) GetLt() *durationpb.Duration { + if x != nil { + if x, ok := x.LessThan.(*DurationRules_Lt); ok { + return x.Lt + } + } + return nil +} + +func (x *DurationRules) GetLte() *durationpb.Duration { + if x != nil { + if x, ok := x.LessThan.(*DurationRules_Lte); ok { + return x.Lte + } + } + return nil +} + +func (x *DurationRules) GetGreaterThan() isDurationRules_GreaterThan { + if x != nil { + return x.GreaterThan + } + return nil +} + +func (x *DurationRules) GetGt() *durationpb.Duration { + if x != nil { + if x, ok := x.GreaterThan.(*DurationRules_Gt); ok { + return x.Gt + } + } + return nil +} + +func (x *DurationRules) GetGte() *durationpb.Duration { + if x != nil { + if x, ok := x.GreaterThan.(*DurationRules_Gte); ok { + return x.Gte + } + } + return nil +} + +func (x *DurationRules) GetIn() []*durationpb.Duration { + if x != nil { + return x.In + } + return nil +} + +func (x *DurationRules) GetNotIn() []*durationpb.Duration { + if x != nil { + return x.NotIn + } + return nil +} + +func (x *DurationRules) GetExample() []*durationpb.Duration { + if x != nil { + return x.Example + } + return nil +} + +type isDurationRules_LessThan interface { + isDurationRules_LessThan() +} + +type DurationRules_Lt struct { + // `lt` stipulates that the field must be less than the specified value of the `google.protobuf.Duration` type, + // exclusive. If the field's value is greater than or equal to the specified + // value, an error message will be generated. 
+ // + // ```proto + // + // message MyDuration { + // // must be less than 5s + // google.protobuf.Duration value = 1 [(buf.validate.field).duration.lt = "5s"]; + // } + // + // ``` + Lt *durationpb.Duration `protobuf:"bytes,3,opt,name=lt,oneof"` +} + +type DurationRules_Lte struct { + // `lte` indicates that the field must be less than or equal to the specified + // value of the `google.protobuf.Duration` type, inclusive. If the field's value is greater than the specified value, + // an error message will be generated. + // + // ```proto + // + // message MyDuration { + // // must be less than or equal to 10s + // google.protobuf.Duration value = 1 [(buf.validate.field).duration.lte = "10s"]; + // } + // + // ``` + Lte *durationpb.Duration `protobuf:"bytes,4,opt,name=lte,oneof"` +} + +func (*DurationRules_Lt) isDurationRules_LessThan() {} + +func (*DurationRules_Lte) isDurationRules_LessThan() {} + +type isDurationRules_GreaterThan interface { + isDurationRules_GreaterThan() +} + +type DurationRules_Gt struct { + // `gt` requires the duration field value to be greater than the specified + // value (exclusive). If the value of `gt` is larger than a specified `lt` + // or `lte`, the range is reversed, and the field value must be outside the + // specified range. If the field value doesn't meet the required conditions, + // an error message is generated. 
+ // + // ```proto + // + // message MyDuration { + // // duration must be greater than 5s [duration.gt] + // google.protobuf.Duration value = 1 [(buf.validate.field).duration.gt = { seconds: 5 }]; + // + // // duration must be greater than 5s and less than 10s [duration.gt_lt] + // google.protobuf.Duration another_value = 2 [(buf.validate.field).duration = { gt: { seconds: 5 }, lt: { seconds: 10 } }]; + // + // // duration must be greater than 10s or less than 5s [duration.gt_lt_exclusive] + // google.protobuf.Duration other_value = 3 [(buf.validate.field).duration = { gt: { seconds: 10 }, lt: { seconds: 5 } }]; + // } + // + // ``` + Gt *durationpb.Duration `protobuf:"bytes,5,opt,name=gt,oneof"` +} + +type DurationRules_Gte struct { + // `gte` requires the duration field value to be greater than or equal to the + // specified value (exclusive). If the value of `gte` is larger than a + // specified `lt` or `lte`, the range is reversed, and the field value must + // be outside the specified range. If the field value doesn't meet the + // required conditions, an error message is generated. 
+ // + // ```proto + // + // message MyDuration { + // // duration must be greater than or equal to 5s [duration.gte] + // google.protobuf.Duration value = 1 [(buf.validate.field).duration.gte = { seconds: 5 }]; + // + // // duration must be greater than or equal to 5s and less than 10s [duration.gte_lt] + // google.protobuf.Duration another_value = 2 [(buf.validate.field).duration = { gte: { seconds: 5 }, lt: { seconds: 10 } }]; + // + // // duration must be greater than or equal to 10s or less than 5s [duration.gte_lt_exclusive] + // google.protobuf.Duration other_value = 3 [(buf.validate.field).duration = { gte: { seconds: 10 }, lt: { seconds: 5 } }]; + // } + // + // ``` + Gte *durationpb.Duration `protobuf:"bytes,6,opt,name=gte,oneof"` +} + +func (*DurationRules_Gt) isDurationRules_GreaterThan() {} + +func (*DurationRules_Gte) isDurationRules_GreaterThan() {} + +// FieldMaskRules describe rules applied exclusively to the `google.protobuf.FieldMask` well-known type. +type FieldMaskRules struct { + state protoimpl.MessageState `protogen:"open.v1"` + // `const` dictates that the field must match the specified value of the `google.protobuf.FieldMask` type exactly. + // If the field's value deviates from the specified value, an error message + // will be generated. + // + // ```proto + // + // message MyFieldMask { + // // value must equal ["a"] + // google.protobuf.FieldMask value = 1 [(buf.validate.field).field_mask.const = { + // paths: ["a"] + // }]; + // } + // + // ``` + Const *fieldmaskpb.FieldMask `protobuf:"bytes,1,opt,name=const" json:"const,omitempty"` + // `in` requires the field value to only contain paths matching specified + // values or their subpaths. + // If any of the field value's paths doesn't match the rule, + // an error message is generated. + // See: https://protobuf.dev/reference/protobuf/google.protobuf/#field-mask + // + // ```proto + // + // message MyFieldMask { + // // The `value` FieldMask must only contain paths listed in `in`. 
+ // google.protobuf.FieldMask value = 1 [(buf.validate.field).field_mask = { + // in: ["a", "b", "c.a"] + // }]; + // } + // + // ``` + In []string `protobuf:"bytes,2,rep,name=in" json:"in,omitempty"` + // `not_in` requires the field value to not contain paths matching specified + // values or their subpaths. + // If any of the field value's paths matches the rule, + // an error message is generated. + // See: https://protobuf.dev/reference/protobuf/google.protobuf/#field-mask + // + // ```proto + // + // message MyFieldMask { + // // The `value` FieldMask shall not contain paths listed in `not_in`. + // google.protobuf.FieldMask value = 1 [(buf.validate.field).field_mask = { + // not_in: ["forbidden", "immutable", "c.a"] + // }]; + // } + // + // ``` + NotIn []string `protobuf:"bytes,3,rep,name=not_in,json=notIn" json:"not_in,omitempty"` + // `example` specifies values that the field may have. These values SHOULD + // conform to other rules. `example` values will not impact validation + // but may be used as helpful guidance on how to populate the given field. 
+ // + // ```proto + // + // message MyFieldMask { + // google.protobuf.FieldMask value = 1 [ + // (buf.validate.field).field_mask.example = { paths: ["a", "b"] }, + // (buf.validate.field).field_mask.example = { paths: ["c.a", "d"] }, + // ]; + // } + // + // ``` + Example []*fieldmaskpb.FieldMask `protobuf:"bytes,4,rep,name=example" json:"example,omitempty"` + extensionFields protoimpl.ExtensionFields + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *FieldMaskRules) Reset() { + *x = FieldMaskRules{} + mi := &file_buf_validate_validate_proto_msgTypes[26] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *FieldMaskRules) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*FieldMaskRules) ProtoMessage() {} + +func (x *FieldMaskRules) ProtoReflect() protoreflect.Message { + mi := &file_buf_validate_validate_proto_msgTypes[26] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use FieldMaskRules.ProtoReflect.Descriptor instead. +func (*FieldMaskRules) Descriptor() ([]byte, []int) { + return file_buf_validate_validate_proto_rawDescGZIP(), []int{26} +} + +func (x *FieldMaskRules) GetConst() *fieldmaskpb.FieldMask { + if x != nil { + return x.Const + } + return nil +} + +func (x *FieldMaskRules) GetIn() []string { + if x != nil { + return x.In + } + return nil +} + +func (x *FieldMaskRules) GetNotIn() []string { + if x != nil { + return x.NotIn + } + return nil +} + +func (x *FieldMaskRules) GetExample() []*fieldmaskpb.FieldMask { + if x != nil { + return x.Example + } + return nil +} + +// TimestampRules describe the rules applied exclusively to the `google.protobuf.Timestamp` well-known type. 
+type TimestampRules struct { + state protoimpl.MessageState `protogen:"open.v1"` + // `const` dictates that this field, of the `google.protobuf.Timestamp` type, must exactly match the specified value. If the field value doesn't correspond to the specified timestamp, an error message will be generated. + // + // ```proto + // + // message MyTimestamp { + // // value must equal 2023-05-03T10:00:00Z + // google.protobuf.Timestamp created_at = 1 [(buf.validate.field).timestamp.const = {seconds: 1727998800}]; + // } + // + // ``` + Const *timestamppb.Timestamp `protobuf:"bytes,2,opt,name=const" json:"const,omitempty"` + // Types that are valid to be assigned to LessThan: + // + // *TimestampRules_Lt + // *TimestampRules_Lte + // *TimestampRules_LtNow + LessThan isTimestampRules_LessThan `protobuf_oneof:"less_than"` + // Types that are valid to be assigned to GreaterThan: + // + // *TimestampRules_Gt + // *TimestampRules_Gte + // *TimestampRules_GtNow + GreaterThan isTimestampRules_GreaterThan `protobuf_oneof:"greater_than"` + // `within` specifies that this field, of the `google.protobuf.Timestamp` type, must be within the specified duration of the current time. If the field value isn't within the duration, an error message is generated. + // + // ```proto + // + // message MyTimestamp { + // // must be within 1 hour of now + // google.protobuf.Timestamp created_at = 1 [(buf.validate.field).timestamp.within = {seconds: 3600}]; + // } + // + // ``` + Within *durationpb.Duration `protobuf:"bytes,9,opt,name=within" json:"within,omitempty"` + // `example` specifies values that the field may have. These values SHOULD + // conform to other rules. `example` values will not impact validation + // but may be used as helpful guidance on how to populate the given field. 
+ // + // ```proto + // + // message MyTimestamp { + // google.protobuf.Timestamp value = 1 [ + // (buf.validate.field).timestamp.example = { seconds: 1672444800 }, + // (buf.validate.field).timestamp.example = { seconds: 1672531200 }, + // ]; + // } + // + // ``` + Example []*timestamppb.Timestamp `protobuf:"bytes,10,rep,name=example" json:"example,omitempty"` + extensionFields protoimpl.ExtensionFields + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *TimestampRules) Reset() { + *x = TimestampRules{} + mi := &file_buf_validate_validate_proto_msgTypes[27] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *TimestampRules) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TimestampRules) ProtoMessage() {} + +func (x *TimestampRules) ProtoReflect() protoreflect.Message { + mi := &file_buf_validate_validate_proto_msgTypes[27] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TimestampRules.ProtoReflect.Descriptor instead. 
+func (*TimestampRules) Descriptor() ([]byte, []int) { + return file_buf_validate_validate_proto_rawDescGZIP(), []int{27} +} + +func (x *TimestampRules) GetConst() *timestamppb.Timestamp { + if x != nil { + return x.Const + } + return nil +} + +func (x *TimestampRules) GetLessThan() isTimestampRules_LessThan { + if x != nil { + return x.LessThan + } + return nil +} + +func (x *TimestampRules) GetLt() *timestamppb.Timestamp { + if x != nil { + if x, ok := x.LessThan.(*TimestampRules_Lt); ok { + return x.Lt + } + } + return nil +} + +func (x *TimestampRules) GetLte() *timestamppb.Timestamp { + if x != nil { + if x, ok := x.LessThan.(*TimestampRules_Lte); ok { + return x.Lte + } + } + return nil +} + +func (x *TimestampRules) GetLtNow() bool { + if x != nil { + if x, ok := x.LessThan.(*TimestampRules_LtNow); ok { + return x.LtNow + } + } + return false +} + +func (x *TimestampRules) GetGreaterThan() isTimestampRules_GreaterThan { + if x != nil { + return x.GreaterThan + } + return nil +} + +func (x *TimestampRules) GetGt() *timestamppb.Timestamp { + if x != nil { + if x, ok := x.GreaterThan.(*TimestampRules_Gt); ok { + return x.Gt + } + } + return nil +} + +func (x *TimestampRules) GetGte() *timestamppb.Timestamp { + if x != nil { + if x, ok := x.GreaterThan.(*TimestampRules_Gte); ok { + return x.Gte + } + } + return nil +} + +func (x *TimestampRules) GetGtNow() bool { + if x != nil { + if x, ok := x.GreaterThan.(*TimestampRules_GtNow); ok { + return x.GtNow + } + } + return false +} + +func (x *TimestampRules) GetWithin() *durationpb.Duration { + if x != nil { + return x.Within + } + return nil +} + +func (x *TimestampRules) GetExample() []*timestamppb.Timestamp { + if x != nil { + return x.Example + } + return nil +} + +type isTimestampRules_LessThan interface { + isTimestampRules_LessThan() +} + +type TimestampRules_Lt struct { + // `lt` requires the timestamp field value to be less than the specified value (field < value). 
If the field value doesn't meet the required conditions, an error message is generated. + // + // ```proto + // + // message MyTimestamp { + // // timestamp must be less than '2023-01-01T00:00:00Z' [timestamp.lt] + // google.protobuf.Timestamp value = 1 [(buf.validate.field).timestamp.lt = { seconds: 1672444800 }]; + // } + // + // ``` + Lt *timestamppb.Timestamp `protobuf:"bytes,3,opt,name=lt,oneof"` +} + +type TimestampRules_Lte struct { + // `lte` requires the timestamp field value to be less than or equal to the specified value (field <= value). If the field value doesn't meet the required conditions, an error message is generated. + // + // ```proto + // + // message MyTimestamp { + // // timestamp must be less than or equal to '2023-05-14T00:00:00Z' [timestamp.lte] + // google.protobuf.Timestamp value = 1 [(buf.validate.field).timestamp.lte = { seconds: 1678867200 }]; + // } + // + // ``` + Lte *timestamppb.Timestamp `protobuf:"bytes,4,opt,name=lte,oneof"` +} + +type TimestampRules_LtNow struct { + // `lt_now` specifies that this field, of the `google.protobuf.Timestamp` type, must be less than the current time. `lt_now` can only be used with the `within` rule. + // + // ```proto + // + // message MyTimestamp { + // // must be less than now + // google.protobuf.Timestamp created_at = 1 [(buf.validate.field).timestamp.lt_now = true]; + // } + // + // ``` + LtNow bool `protobuf:"varint,7,opt,name=lt_now,json=ltNow,oneof"` +} + +func (*TimestampRules_Lt) isTimestampRules_LessThan() {} + +func (*TimestampRules_Lte) isTimestampRules_LessThan() {} + +func (*TimestampRules_LtNow) isTimestampRules_LessThan() {} + +type isTimestampRules_GreaterThan interface { + isTimestampRules_GreaterThan() +} + +type TimestampRules_Gt struct { + // `gt` requires the timestamp field value to be greater than the specified + // value (exclusive). 
If the value of `gt` is larger than a specified `lt` + // or `lte`, the range is reversed, and the field value must be outside the + // specified range. If the field value doesn't meet the required conditions, + // an error message is generated. + // + // ```proto + // + // message MyTimestamp { + // // timestamp must be greater than '2023-01-01T00:00:00Z' [timestamp.gt] + // google.protobuf.Timestamp value = 1 [(buf.validate.field).timestamp.gt = { seconds: 1672444800 }]; + // + // // timestamp must be greater than '2023-01-01T00:00:00Z' and less than '2023-01-02T00:00:00Z' [timestamp.gt_lt] + // google.protobuf.Timestamp another_value = 2 [(buf.validate.field).timestamp = { gt: { seconds: 1672444800 }, lt: { seconds: 1672531200 } }]; + // + // // timestamp must be greater than '2023-01-02T00:00:00Z' or less than '2023-01-01T00:00:00Z' [timestamp.gt_lt_exclusive] + // google.protobuf.Timestamp other_value = 3 [(buf.validate.field).timestamp = { gt: { seconds: 1672531200 }, lt: { seconds: 1672444800 } }]; + // } + // + // ``` + Gt *timestamppb.Timestamp `protobuf:"bytes,5,opt,name=gt,oneof"` +} + +type TimestampRules_Gte struct { + // `gte` requires the timestamp field value to be greater than or equal to the + // specified value (exclusive). If the value of `gte` is larger than a + // specified `lt` or `lte`, the range is reversed, and the field value + // must be outside the specified range. If the field value doesn't meet + // the required conditions, an error message is generated. 
+ // + // ```proto + // + // message MyTimestamp { + // // timestamp must be greater than or equal to '2023-01-01T00:00:00Z' [timestamp.gte] + // google.protobuf.Timestamp value = 1 [(buf.validate.field).timestamp.gte = { seconds: 1672444800 }]; + // + // // timestamp must be greater than or equal to '2023-01-01T00:00:00Z' and less than '2023-01-02T00:00:00Z' [timestamp.gte_lt] + // google.protobuf.Timestamp another_value = 2 [(buf.validate.field).timestamp = { gte: { seconds: 1672444800 }, lt: { seconds: 1672531200 } }]; + // + // // timestamp must be greater than or equal to '2023-01-02T00:00:00Z' or less than '2023-01-01T00:00:00Z' [timestamp.gte_lt_exclusive] + // google.protobuf.Timestamp other_value = 3 [(buf.validate.field).timestamp = { gte: { seconds: 1672531200 }, lt: { seconds: 1672444800 } }]; + // } + // + // ``` + Gte *timestamppb.Timestamp `protobuf:"bytes,6,opt,name=gte,oneof"` +} + +type TimestampRules_GtNow struct { + // `gt_now` specifies that this field, of the `google.protobuf.Timestamp` type, must be greater than the current time. `gt_now` can only be used with the `within` rule. + // + // ```proto + // + // message MyTimestamp { + // // must be greater than now + // google.protobuf.Timestamp created_at = 1 [(buf.validate.field).timestamp.gt_now = true]; + // } + // + // ``` + GtNow bool `protobuf:"varint,8,opt,name=gt_now,json=gtNow,oneof"` +} + +func (*TimestampRules_Gt) isTimestampRules_GreaterThan() {} + +func (*TimestampRules_Gte) isTimestampRules_GreaterThan() {} + +func (*TimestampRules_GtNow) isTimestampRules_GreaterThan() {} + +// `Violations` is a collection of `Violation` messages. This message type is returned by +// Protovalidate when a proto message fails to meet the requirements set by the `Rule` validation rules. +// Each individual violation is represented by a `Violation` message. 
+type Violations struct { + state protoimpl.MessageState `protogen:"open.v1"` + // `violations` is a repeated field that contains all the `Violation` messages corresponding to the violations detected. + Violations []*Violation `protobuf:"bytes,1,rep,name=violations" json:"violations,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Violations) Reset() { + *x = Violations{} + mi := &file_buf_validate_validate_proto_msgTypes[28] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Violations) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Violations) ProtoMessage() {} + +func (x *Violations) ProtoReflect() protoreflect.Message { + mi := &file_buf_validate_validate_proto_msgTypes[28] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Violations.ProtoReflect.Descriptor instead. +func (*Violations) Descriptor() ([]byte, []int) { + return file_buf_validate_validate_proto_rawDescGZIP(), []int{28} +} + +func (x *Violations) GetViolations() []*Violation { + if x != nil { + return x.Violations + } + return nil +} + +// `Violation` represents a single instance where a validation rule, expressed +// as a `Rule`, was not met. It provides information about the field that +// caused the violation, the specific rule that wasn't fulfilled, and a +// human-readable error message. +// +// For example, consider the following message: +// +// ```proto +// +// message User { +// int32 age = 1 [(buf.validate.field).cel = { +// id: "user.age", +// expression: "this < 18 ? 
'User must be at least 18 years old' : ''", +// }]; +// } +// +// ``` +// +// It could produce the following violation: +// +// ```json +// +// { +// "ruleId": "user.age", +// "message": "User must be at least 18 years old", +// "field": { +// "elements": [ +// { +// "fieldNumber": 1, +// "fieldName": "age", +// "fieldType": "TYPE_INT32" +// } +// ] +// }, +// "rule": { +// "elements": [ +// { +// "fieldNumber": 23, +// "fieldName": "cel", +// "fieldType": "TYPE_MESSAGE", +// "index": "0" +// } +// ] +// } +// } +// +// ``` +type Violation struct { + state protoimpl.MessageState `protogen:"open.v1"` + // `field` is a machine-readable path to the field that failed validation. + // This could be a nested field, in which case the path will include all the parent fields leading to the actual field that caused the violation. + // + // For example, consider the following message: + // + // ```proto + // + // message Message { + // bool a = 1 [(buf.validate.field).required = true]; + // } + // + // ``` + // + // It could produce the following violation: + // + // ```textproto + // + // violation { + // field { element { field_number: 1, field_name: "a", field_type: 8 } } + // ... + // } + // + // ``` + Field *FieldPath `protobuf:"bytes,5,opt,name=field" json:"field,omitempty"` + // `rule` is a machine-readable path that points to the specific rule that failed validation. + // This will be a nested field starting from the FieldRules of the field that failed validation. + // For custom rules, this will provide the path of the rule, e.g. `cel[0]`. + // + // For example, consider the following message: + // + // ```proto + // + // message Message { + // bool a = 1 [(buf.validate.field).required = true]; + // bool b = 2 [(buf.validate.field).cel = { + // id: "custom_rule", + // expression: "!this ? 
'b must be true': ''" + // }] + // } + // + // ``` + // + // It could produce the following violations: + // + // ```textproto + // + // violation { + // rule { element { field_number: 25, field_name: "required", field_type: 8 } } + // ... + // } + // + // violation { + // rule { element { field_number: 23, field_name: "cel", field_type: 11, index: 0 } } + // ... + // } + // + // ``` + Rule *FieldPath `protobuf:"bytes,6,opt,name=rule" json:"rule,omitempty"` + // `rule_id` is the unique identifier of the `Rule` that was not fulfilled. + // This is the same `id` that was specified in the `Rule` message, allowing easy tracing of which rule was violated. + RuleId *string `protobuf:"bytes,2,opt,name=rule_id,json=ruleId" json:"rule_id,omitempty"` + // `message` is a human-readable error message that describes the nature of the violation. + // This can be the default error message from the violated `Rule`, or it can be a custom message that gives more context about the violation. + Message *string `protobuf:"bytes,3,opt,name=message" json:"message,omitempty"` + // `for_key` indicates whether the violation was caused by a map key, rather than a value. 
+ ForKey *bool `protobuf:"varint,4,opt,name=for_key,json=forKey" json:"for_key,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Violation) Reset() { + *x = Violation{} + mi := &file_buf_validate_validate_proto_msgTypes[29] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Violation) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Violation) ProtoMessage() {} + +func (x *Violation) ProtoReflect() protoreflect.Message { + mi := &file_buf_validate_validate_proto_msgTypes[29] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Violation.ProtoReflect.Descriptor instead. +func (*Violation) Descriptor() ([]byte, []int) { + return file_buf_validate_validate_proto_rawDescGZIP(), []int{29} +} + +func (x *Violation) GetField() *FieldPath { + if x != nil { + return x.Field + } + return nil +} + +func (x *Violation) GetRule() *FieldPath { + if x != nil { + return x.Rule + } + return nil +} + +func (x *Violation) GetRuleId() string { + if x != nil && x.RuleId != nil { + return *x.RuleId + } + return "" +} + +func (x *Violation) GetMessage() string { + if x != nil && x.Message != nil { + return *x.Message + } + return "" +} + +func (x *Violation) GetForKey() bool { + if x != nil && x.ForKey != nil { + return *x.ForKey + } + return false +} + +// `FieldPath` provides a path to a nested protobuf field. +// +// This message provides enough information to render a dotted field path even without protobuf descriptors. +// It also provides enough information to resolve a nested field through unknown wire data. +type FieldPath struct { + state protoimpl.MessageState `protogen:"open.v1"` + // `elements` contains each element of the path, starting from the root and recursing downward. 
+ Elements []*FieldPathElement `protobuf:"bytes,1,rep,name=elements" json:"elements,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *FieldPath) Reset() { + *x = FieldPath{} + mi := &file_buf_validate_validate_proto_msgTypes[30] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *FieldPath) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*FieldPath) ProtoMessage() {} + +func (x *FieldPath) ProtoReflect() protoreflect.Message { + mi := &file_buf_validate_validate_proto_msgTypes[30] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use FieldPath.ProtoReflect.Descriptor instead. +func (*FieldPath) Descriptor() ([]byte, []int) { + return file_buf_validate_validate_proto_rawDescGZIP(), []int{30} +} + +func (x *FieldPath) GetElements() []*FieldPathElement { + if x != nil { + return x.Elements + } + return nil +} + +// `FieldPathElement` provides enough information to nest through a single protobuf field. +// +// If the selected field is a map or repeated field, the `subscript` value selects a specific element from it. +// A path that refers to a value nested under a map key or repeated field index will have a `subscript` value. +// The `field_type` field allows unambiguous resolution of a field even if descriptors are not available. +type FieldPathElement struct { + state protoimpl.MessageState `protogen:"open.v1"` + // `field_number` is the field number this path element refers to. + FieldNumber *int32 `protobuf:"varint,1,opt,name=field_number,json=fieldNumber" json:"field_number,omitempty"` + // `field_name` contains the field name this path element refers to. + // This can be used to display a human-readable path even if the field number is unknown. 
+ FieldName *string `protobuf:"bytes,2,opt,name=field_name,json=fieldName" json:"field_name,omitempty"` + // `field_type` specifies the type of this field. When using reflection, this value is not needed. + // + // This value is provided to make it possible to traverse unknown fields through wire data. + // When traversing wire data, be mindful of both packed[1] and delimited[2] encoding schemes. + // + // N.B.: Although groups are deprecated, the corresponding delimited encoding scheme is not, and + // can be explicitly used in Protocol Buffers 2023 Edition. + // + // [1]: https://protobuf.dev/programming-guides/encoding/#packed + // [2]: https://protobuf.dev/programming-guides/encoding/#groups + FieldType *descriptorpb.FieldDescriptorProto_Type `protobuf:"varint,3,opt,name=field_type,json=fieldType,enum=google.protobuf.FieldDescriptorProto_Type" json:"field_type,omitempty"` + // `key_type` specifies the map key type of this field. This value is useful when traversing + // unknown fields through wire data: specifically, it allows handling the differences between + // different integer encodings. + KeyType *descriptorpb.FieldDescriptorProto_Type `protobuf:"varint,4,opt,name=key_type,json=keyType,enum=google.protobuf.FieldDescriptorProto_Type" json:"key_type,omitempty"` + // `value_type` specifies map value type of this field. This is useful if you want to display a + // value inside unknown fields through wire data. + ValueType *descriptorpb.FieldDescriptorProto_Type `protobuf:"varint,5,opt,name=value_type,json=valueType,enum=google.protobuf.FieldDescriptorProto_Type" json:"value_type,omitempty"` + // `subscript` contains a repeated index or map key, if this path element nests into a repeated or map field. 
+ // + // Types that are valid to be assigned to Subscript: + // + // *FieldPathElement_Index + // *FieldPathElement_BoolKey + // *FieldPathElement_IntKey + // *FieldPathElement_UintKey + // *FieldPathElement_StringKey + Subscript isFieldPathElement_Subscript `protobuf_oneof:"subscript"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *FieldPathElement) Reset() { + *x = FieldPathElement{} + mi := &file_buf_validate_validate_proto_msgTypes[31] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *FieldPathElement) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*FieldPathElement) ProtoMessage() {} + +func (x *FieldPathElement) ProtoReflect() protoreflect.Message { + mi := &file_buf_validate_validate_proto_msgTypes[31] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use FieldPathElement.ProtoReflect.Descriptor instead. 
+func (*FieldPathElement) Descriptor() ([]byte, []int) { + return file_buf_validate_validate_proto_rawDescGZIP(), []int{31} +} + +func (x *FieldPathElement) GetFieldNumber() int32 { + if x != nil && x.FieldNumber != nil { + return *x.FieldNumber + } + return 0 +} + +func (x *FieldPathElement) GetFieldName() string { + if x != nil && x.FieldName != nil { + return *x.FieldName + } + return "" +} + +func (x *FieldPathElement) GetFieldType() descriptorpb.FieldDescriptorProto_Type { + if x != nil && x.FieldType != nil { + return *x.FieldType + } + return descriptorpb.FieldDescriptorProto_Type(1) +} + +func (x *FieldPathElement) GetKeyType() descriptorpb.FieldDescriptorProto_Type { + if x != nil && x.KeyType != nil { + return *x.KeyType + } + return descriptorpb.FieldDescriptorProto_Type(1) +} + +func (x *FieldPathElement) GetValueType() descriptorpb.FieldDescriptorProto_Type { + if x != nil && x.ValueType != nil { + return *x.ValueType + } + return descriptorpb.FieldDescriptorProto_Type(1) +} + +func (x *FieldPathElement) GetSubscript() isFieldPathElement_Subscript { + if x != nil { + return x.Subscript + } + return nil +} + +func (x *FieldPathElement) GetIndex() uint64 { + if x != nil { + if x, ok := x.Subscript.(*FieldPathElement_Index); ok { + return x.Index + } + } + return 0 +} + +func (x *FieldPathElement) GetBoolKey() bool { + if x != nil { + if x, ok := x.Subscript.(*FieldPathElement_BoolKey); ok { + return x.BoolKey + } + } + return false +} + +func (x *FieldPathElement) GetIntKey() int64 { + if x != nil { + if x, ok := x.Subscript.(*FieldPathElement_IntKey); ok { + return x.IntKey + } + } + return 0 +} + +func (x *FieldPathElement) GetUintKey() uint64 { + if x != nil { + if x, ok := x.Subscript.(*FieldPathElement_UintKey); ok { + return x.UintKey + } + } + return 0 +} + +func (x *FieldPathElement) GetStringKey() string { + if x != nil { + if x, ok := x.Subscript.(*FieldPathElement_StringKey); ok { + return x.StringKey + } + } + return "" +} + +type 
isFieldPathElement_Subscript interface { + isFieldPathElement_Subscript() +} + +type FieldPathElement_Index struct { + // `index` specifies a 0-based index into a repeated field. + Index uint64 `protobuf:"varint,6,opt,name=index,oneof"` +} + +type FieldPathElement_BoolKey struct { + // `bool_key` specifies a map key of type bool. + BoolKey bool `protobuf:"varint,7,opt,name=bool_key,json=boolKey,oneof"` +} + +type FieldPathElement_IntKey struct { + // `int_key` specifies a map key of type int32, int64, sint32, sint64, sfixed32 or sfixed64. + IntKey int64 `protobuf:"varint,8,opt,name=int_key,json=intKey,oneof"` +} + +type FieldPathElement_UintKey struct { + // `uint_key` specifies a map key of type uint32, uint64, fixed32 or fixed64. + UintKey uint64 `protobuf:"varint,9,opt,name=uint_key,json=uintKey,oneof"` +} + +type FieldPathElement_StringKey struct { + // `string_key` specifies a map key of type string. + StringKey string `protobuf:"bytes,10,opt,name=string_key,json=stringKey,oneof"` +} + +func (*FieldPathElement_Index) isFieldPathElement_Subscript() {} + +func (*FieldPathElement_BoolKey) isFieldPathElement_Subscript() {} + +func (*FieldPathElement_IntKey) isFieldPathElement_Subscript() {} + +func (*FieldPathElement_UintKey) isFieldPathElement_Subscript() {} + +func (*FieldPathElement_StringKey) isFieldPathElement_Subscript() {} + +var file_buf_validate_validate_proto_extTypes = []protoimpl.ExtensionInfo{ + { + ExtendedType: (*descriptorpb.MessageOptions)(nil), + ExtensionType: (*MessageRules)(nil), + Field: 1159, + Name: "buf.validate.message", + Tag: "bytes,1159,opt,name=message", + Filename: "buf/validate/validate.proto", + }, + { + ExtendedType: (*descriptorpb.OneofOptions)(nil), + ExtensionType: (*OneofRules)(nil), + Field: 1159, + Name: "buf.validate.oneof", + Tag: "bytes,1159,opt,name=oneof", + Filename: "buf/validate/validate.proto", + }, + { + ExtendedType: (*descriptorpb.FieldOptions)(nil), + ExtensionType: (*FieldRules)(nil), + Field: 1159, + Name: 
"buf.validate.field", + Tag: "bytes,1159,opt,name=field", + Filename: "buf/validate/validate.proto", + }, + { + ExtendedType: (*descriptorpb.FieldOptions)(nil), + ExtensionType: (*PredefinedRules)(nil), + Field: 1160, + Name: "buf.validate.predefined", + Tag: "bytes,1160,opt,name=predefined", + Filename: "buf/validate/validate.proto", + }, +} + +// Extension fields to descriptorpb.MessageOptions. +var ( + // Rules specify the validations to be performed on this message. By default, + // no validation is performed against a message. + // + // optional buf.validate.MessageRules message = 1159; + E_Message = &file_buf_validate_validate_proto_extTypes[0] +) + +// Extension fields to descriptorpb.OneofOptions. +var ( + // Rules specify the validations to be performed on this oneof. By default, + // no validation is performed against a oneof. + // + // optional buf.validate.OneofRules oneof = 1159; + E_Oneof = &file_buf_validate_validate_proto_extTypes[1] +) + +// Extension fields to descriptorpb.FieldOptions. +var ( + // Rules specify the validations to be performed on this field. By default, + // no validation is performed against a field. + // + // optional buf.validate.FieldRules field = 1159; + E_Field = &file_buf_validate_validate_proto_extTypes[2] + // Specifies predefined rules. When extending a standard rule message, + // this adds additional CEL expressions that apply when the extension is used. 
+ // + // ```proto + // + // extend buf.validate.Int32Rules { + // bool is_zero [(buf.validate.predefined).cel = { + // id: "int32.is_zero", + // message: "must be zero", + // expression: "!rule || this == 0", + // }]; + // } + // + // message Foo { + // int32 reserved = 1 [(buf.validate.field).int32.(is_zero) = true]; + // } + // + // ``` + // + // optional buf.validate.PredefinedRules predefined = 1160; + E_Predefined = &file_buf_validate_validate_proto_extTypes[3] +) + +var File_buf_validate_validate_proto protoreflect.FileDescriptor + +const file_buf_validate_validate_proto_rawDesc = "" + + "\n" + + "\x1bbuf/validate/validate.proto\x12\fbuf.validate\x1a google/protobuf/descriptor.proto\x1a\x1egoogle/protobuf/duration.proto\x1a google/protobuf/field_mask.proto\x1a\x1fgoogle/protobuf/timestamp.proto\"P\n" + + "\x04Rule\x12\x0e\n" + + "\x02id\x18\x01 \x01(\tR\x02id\x12\x18\n" + + "\amessage\x18\x02 \x01(\tR\amessage\x12\x1e\n" + + "\n" + + "expression\x18\x03 \x01(\tR\n" + + "expression\"\xa1\x01\n" + + "\fMessageRules\x12%\n" + + "\x0ecel_expression\x18\x05 \x03(\tR\rcelExpression\x12$\n" + + "\x03cel\x18\x03 \x03(\v2\x12.buf.validate.RuleR\x03cel\x124\n" + + "\x05oneof\x18\x04 \x03(\v2\x1e.buf.validate.MessageOneofRuleR\x05oneofJ\x04\b\x01\x10\x02R\bdisabled\"F\n" + + "\x10MessageOneofRule\x12\x16\n" + + "\x06fields\x18\x01 \x03(\tR\x06fields\x12\x1a\n" + + "\brequired\x18\x02 \x01(\bR\brequired\"(\n" + + "\n" + + "OneofRules\x12\x1a\n" + + "\brequired\x18\x01 \x01(\bR\brequired\"\xe3\n" + + "\n" + + "\n" + + "FieldRules\x12%\n" + + "\x0ecel_expression\x18\x1d \x03(\tR\rcelExpression\x12$\n" + + "\x03cel\x18\x17 \x03(\v2\x12.buf.validate.RuleR\x03cel\x12\x1a\n" + + "\brequired\x18\x19 \x01(\bR\brequired\x12,\n" + + "\x06ignore\x18\x1b \x01(\x0e2\x14.buf.validate.IgnoreR\x06ignore\x120\n" + + "\x05float\x18\x01 \x01(\v2\x18.buf.validate.FloatRulesH\x00R\x05float\x123\n" + + "\x06double\x18\x02 \x01(\v2\x19.buf.validate.DoubleRulesH\x00R\x06double\x120\n" + + 
"\x05int32\x18\x03 \x01(\v2\x18.buf.validate.Int32RulesH\x00R\x05int32\x120\n" + + "\x05int64\x18\x04 \x01(\v2\x18.buf.validate.Int64RulesH\x00R\x05int64\x123\n" + + "\x06uint32\x18\x05 \x01(\v2\x19.buf.validate.UInt32RulesH\x00R\x06uint32\x123\n" + + "\x06uint64\x18\x06 \x01(\v2\x19.buf.validate.UInt64RulesH\x00R\x06uint64\x123\n" + + "\x06sint32\x18\a \x01(\v2\x19.buf.validate.SInt32RulesH\x00R\x06sint32\x123\n" + + "\x06sint64\x18\b \x01(\v2\x19.buf.validate.SInt64RulesH\x00R\x06sint64\x126\n" + + "\afixed32\x18\t \x01(\v2\x1a.buf.validate.Fixed32RulesH\x00R\afixed32\x126\n" + + "\afixed64\x18\n" + + " \x01(\v2\x1a.buf.validate.Fixed64RulesH\x00R\afixed64\x129\n" + + "\bsfixed32\x18\v \x01(\v2\x1b.buf.validate.SFixed32RulesH\x00R\bsfixed32\x129\n" + + "\bsfixed64\x18\f \x01(\v2\x1b.buf.validate.SFixed64RulesH\x00R\bsfixed64\x12-\n" + + "\x04bool\x18\r \x01(\v2\x17.buf.validate.BoolRulesH\x00R\x04bool\x123\n" + + "\x06string\x18\x0e \x01(\v2\x19.buf.validate.StringRulesH\x00R\x06string\x120\n" + + "\x05bytes\x18\x0f \x01(\v2\x18.buf.validate.BytesRulesH\x00R\x05bytes\x12-\n" + + "\x04enum\x18\x10 \x01(\v2\x17.buf.validate.EnumRulesH\x00R\x04enum\x129\n" + + "\brepeated\x18\x12 \x01(\v2\x1b.buf.validate.RepeatedRulesH\x00R\brepeated\x12*\n" + + "\x03map\x18\x13 \x01(\v2\x16.buf.validate.MapRulesH\x00R\x03map\x12*\n" + + "\x03any\x18\x14 \x01(\v2\x16.buf.validate.AnyRulesH\x00R\x03any\x129\n" + + "\bduration\x18\x15 \x01(\v2\x1b.buf.validate.DurationRulesH\x00R\bduration\x12=\n" + + "\n" + + "field_mask\x18\x1c \x01(\v2\x1c.buf.validate.FieldMaskRulesH\x00R\tfieldMask\x12<\n" + + "\ttimestamp\x18\x16 \x01(\v2\x1c.buf.validate.TimestampRulesH\x00R\ttimestampB\x06\n" + + "\x04typeJ\x04\b\x18\x10\x19J\x04\b\x1a\x10\x1bR\askippedR\fignore_empty\"Z\n" + + "\x0fPredefinedRules\x12$\n" + + "\x03cel\x18\x01 \x03(\v2\x12.buf.validate.RuleR\x03celJ\x04\b\x18\x10\x19J\x04\b\x1a\x10\x1bR\askippedR\fignore_empty\"\xae\x17\n" + + "\n" + + "FloatRules\x12\x84\x01\n" + + 
"\x05const\x18\x01 \x01(\x02Bn\xc2Hk\n" + + "i\n" + + "\vfloat.const\x1aZthis != getField(rules, 'const') ? 'must equal %s'.format([getField(rules, 'const')]) : ''R\x05const\x12\x9d\x01\n" + + "\x02lt\x18\x02 \x01(\x02B\x8a\x01\xc2H\x86\x01\n" + + "\x83\x01\n" + + "\bfloat.lt\x1aw!has(rules.gte) && !has(rules.gt) && (this.isNan() || this >= rules.lt)? 'must be less than %s'.format([rules.lt]) : ''H\x00R\x02lt\x12\xae\x01\n" + + "\x03lte\x18\x03 \x01(\x02B\x99\x01\xc2H\x95\x01\n" + + "\x92\x01\n" + + "\tfloat.lte\x1a\x84\x01!has(rules.gte) && !has(rules.gt) && (this.isNan() || this > rules.lte)? 'must be less than or equal to %s'.format([rules.lte]) : ''H\x00R\x03lte\x12\xd4\a\n" + + "\x02gt\x18\x04 \x01(\x02B\xc1\a\xc2H\xbd\a\n" + + "\x86\x01\n" + + "\bfloat.gt\x1az!has(rules.lt) && !has(rules.lte) && (this.isNan() || this <= rules.gt)? 'must be greater than %s'.format([rules.gt]) : ''\n" + + "\xbd\x01\n" + + "\vfloat.gt_lt\x1a\xad\x01has(rules.lt) && rules.lt >= rules.gt && (this.isNan() || this >= rules.lt || this <= rules.gt)? 'must be greater than %s and less than %s'.format([rules.gt, rules.lt]) : ''\n" + + "\xc7\x01\n" + + "\x15float.gt_lt_exclusive\x1a\xad\x01has(rules.lt) && rules.lt < rules.gt && (this.isNan() || (rules.lt <= this && this <= rules.gt))? 'must be greater than %s or less than %s'.format([rules.gt, rules.lt]) : ''\n" + + "\xcd\x01\n" + + "\ffloat.gt_lte\x1a\xbc\x01has(rules.lte) && rules.lte >= rules.gt && (this.isNan() || this > rules.lte || this <= rules.gt)? 'must be greater than %s and less than or equal to %s'.format([rules.gt, rules.lte]) : ''\n" + + "\xd7\x01\n" + + "\x16float.gt_lte_exclusive\x1a\xbc\x01has(rules.lte) && rules.lte < rules.gt && (this.isNan() || (rules.lte < this && this <= rules.gt))? 
'must be greater than %s or less than or equal to %s'.format([rules.gt, rules.lte]) : ''H\x01R\x02gt\x12\xa1\b\n" + + "\x03gte\x18\x05 \x01(\x02B\x8c\b\xc2H\x88\b\n" + + "\x95\x01\n" + + "\tfloat.gte\x1a\x87\x01!has(rules.lt) && !has(rules.lte) && (this.isNan() || this < rules.gte)? 'must be greater than or equal to %s'.format([rules.gte]) : ''\n" + + "\xcc\x01\n" + + "\ffloat.gte_lt\x1a\xbb\x01has(rules.lt) && rules.lt >= rules.gte && (this.isNan() || this >= rules.lt || this < rules.gte)? 'must be greater than or equal to %s and less than %s'.format([rules.gte, rules.lt]) : ''\n" + + "\xd6\x01\n" + + "\x16float.gte_lt_exclusive\x1a\xbb\x01has(rules.lt) && rules.lt < rules.gte && (this.isNan() || (rules.lt <= this && this < rules.gte))? 'must be greater than or equal to %s or less than %s'.format([rules.gte, rules.lt]) : ''\n" + + "\xdc\x01\n" + + "\rfloat.gte_lte\x1a\xca\x01has(rules.lte) && rules.lte >= rules.gte && (this.isNan() || this > rules.lte || this < rules.gte)? 'must be greater than or equal to %s and less than or equal to %s'.format([rules.gte, rules.lte]) : ''\n" + + "\xe6\x01\n" + + "\x17float.gte_lte_exclusive\x1a\xca\x01has(rules.lte) && rules.lte < rules.gte && (this.isNan() || (rules.lte < this && this < rules.gte))? 'must be greater than or equal to %s or less than or equal to %s'.format([rules.gte, rules.lte]) : ''H\x01R\x03gte\x12}\n" + + "\x02in\x18\x06 \x03(\x02Bm\xc2Hj\n" + + "h\n" + + "\bfloat.in\x1a\\!(this in getField(rules, 'in')) ? 'must be in list %s'.format([getField(rules, 'in')]) : ''R\x02in\x12w\n" + + "\x06not_in\x18\a \x03(\x02B`\xc2H]\n" + + "[\n" + + "\ffloat.not_in\x1aKthis in rules.not_in ? 'must not be in list %s'.format([rules.not_in]) : ''R\x05notIn\x12w\n" + + "\x06finite\x18\b \x01(\bB_\xc2H\\\n" + + "Z\n" + + "\ffloat.finite\x1aJrules.finite ? (this.isNan() || this.isInf() ? 
'must be finite' : '') : ''R\x06finite\x124\n" + + "\aexample\x18\t \x03(\x02B\x1a\xc2H\x17\n" + + "\x15\n" + + "\rfloat.example\x1a\x04trueR\aexample*\t\b\xe8\a\x10\x80\x80\x80\x80\x02B\v\n" + + "\tless_thanB\x0e\n" + + "\fgreater_than\"\xc0\x17\n" + + "\vDoubleRules\x12\x85\x01\n" + + "\x05const\x18\x01 \x01(\x01Bo\xc2Hl\n" + + "j\n" + + "\fdouble.const\x1aZthis != getField(rules, 'const') ? 'must equal %s'.format([getField(rules, 'const')]) : ''R\x05const\x12\x9e\x01\n" + + "\x02lt\x18\x02 \x01(\x01B\x8b\x01\xc2H\x87\x01\n" + + "\x84\x01\n" + + "\tdouble.lt\x1aw!has(rules.gte) && !has(rules.gt) && (this.isNan() || this >= rules.lt)? 'must be less than %s'.format([rules.lt]) : ''H\x00R\x02lt\x12\xaf\x01\n" + + "\x03lte\x18\x03 \x01(\x01B\x9a\x01\xc2H\x96\x01\n" + + "\x93\x01\n" + + "\n" + + "double.lte\x1a\x84\x01!has(rules.gte) && !has(rules.gt) && (this.isNan() || this > rules.lte)? 'must be less than or equal to %s'.format([rules.lte]) : ''H\x00R\x03lte\x12\xd9\a\n" + + "\x02gt\x18\x04 \x01(\x01B\xc6\a\xc2H\xc2\a\n" + + "\x87\x01\n" + + "\tdouble.gt\x1az!has(rules.lt) && !has(rules.lte) && (this.isNan() || this <= rules.gt)? 'must be greater than %s'.format([rules.gt]) : ''\n" + + "\xbe\x01\n" + + "\fdouble.gt_lt\x1a\xad\x01has(rules.lt) && rules.lt >= rules.gt && (this.isNan() || this >= rules.lt || this <= rules.gt)? 'must be greater than %s and less than %s'.format([rules.gt, rules.lt]) : ''\n" + + "\xc8\x01\n" + + "\x16double.gt_lt_exclusive\x1a\xad\x01has(rules.lt) && rules.lt < rules.gt && (this.isNan() || (rules.lt <= this && this <= rules.gt))? 'must be greater than %s or less than %s'.format([rules.gt, rules.lt]) : ''\n" + + "\xce\x01\n" + + "\rdouble.gt_lte\x1a\xbc\x01has(rules.lte) && rules.lte >= rules.gt && (this.isNan() || this > rules.lte || this <= rules.gt)? 
'must be greater than %s and less than or equal to %s'.format([rules.gt, rules.lte]) : ''\n" + + "\xd8\x01\n" + + "\x17double.gt_lte_exclusive\x1a\xbc\x01has(rules.lte) && rules.lte < rules.gt && (this.isNan() || (rules.lte < this && this <= rules.gt))? 'must be greater than %s or less than or equal to %s'.format([rules.gt, rules.lte]) : ''H\x01R\x02gt\x12\xa6\b\n" + + "\x03gte\x18\x05 \x01(\x01B\x91\b\xc2H\x8d\b\n" + + "\x96\x01\n" + + "\n" + + "double.gte\x1a\x87\x01!has(rules.lt) && !has(rules.lte) && (this.isNan() || this < rules.gte)? 'must be greater than or equal to %s'.format([rules.gte]) : ''\n" + + "\xcd\x01\n" + + "\rdouble.gte_lt\x1a\xbb\x01has(rules.lt) && rules.lt >= rules.gte && (this.isNan() || this >= rules.lt || this < rules.gte)? 'must be greater than or equal to %s and less than %s'.format([rules.gte, rules.lt]) : ''\n" + + "\xd7\x01\n" + + "\x17double.gte_lt_exclusive\x1a\xbb\x01has(rules.lt) && rules.lt < rules.gte && (this.isNan() || (rules.lt <= this && this < rules.gte))? 'must be greater than or equal to %s or less than %s'.format([rules.gte, rules.lt]) : ''\n" + + "\xdd\x01\n" + + "\x0edouble.gte_lte\x1a\xca\x01has(rules.lte) && rules.lte >= rules.gte && (this.isNan() || this > rules.lte || this < rules.gte)? 'must be greater than or equal to %s and less than or equal to %s'.format([rules.gte, rules.lte]) : ''\n" + + "\xe7\x01\n" + + "\x18double.gte_lte_exclusive\x1a\xca\x01has(rules.lte) && rules.lte < rules.gte && (this.isNan() || (rules.lte < this && this < rules.gte))? 'must be greater than or equal to %s or less than or equal to %s'.format([rules.gte, rules.lte]) : ''H\x01R\x03gte\x12~\n" + + "\x02in\x18\x06 \x03(\x01Bn\xc2Hk\n" + + "i\n" + + "\tdouble.in\x1a\\!(this in getField(rules, 'in')) ? 'must be in list %s'.format([getField(rules, 'in')]) : ''R\x02in\x12x\n" + + "\x06not_in\x18\a \x03(\x01Ba\xc2H^\n" + + "\\\n" + + "\rdouble.not_in\x1aKthis in rules.not_in ? 
'must not be in list %s'.format([rules.not_in]) : ''R\x05notIn\x12x\n" + + "\x06finite\x18\b \x01(\bB`\xc2H]\n" + + "[\n" + + "\rdouble.finite\x1aJrules.finite ? (this.isNan() || this.isInf() ? 'must be finite' : '') : ''R\x06finite\x125\n" + + "\aexample\x18\t \x03(\x01B\x1b\xc2H\x18\n" + + "\x16\n" + + "\x0edouble.example\x1a\x04trueR\aexample*\t\b\xe8\a\x10\x80\x80\x80\x80\x02B\v\n" + + "\tless_thanB\x0e\n" + + "\fgreater_than\"\xde\x14\n" + + "\n" + + "Int32Rules\x12\x84\x01\n" + + "\x05const\x18\x01 \x01(\x05Bn\xc2Hk\n" + + "i\n" + + "\vint32.const\x1aZthis != getField(rules, 'const') ? 'must equal %s'.format([getField(rules, 'const')]) : ''R\x05const\x12\x88\x01\n" + + "\x02lt\x18\x02 \x01(\x05Bv\xc2Hs\n" + + "q\n" + + "\bint32.lt\x1ae!has(rules.gte) && !has(rules.gt) && this >= rules.lt? 'must be less than %s'.format([rules.lt]) : ''H\x00R\x02lt\x12\x9a\x01\n" + + "\x03lte\x18\x03 \x01(\x05B\x85\x01\xc2H\x81\x01\n" + + "\x7f\n" + + "\tint32.lte\x1ar!has(rules.gte) && !has(rules.gt) && this > rules.lte? 'must be less than or equal to %s'.format([rules.lte]) : ''H\x00R\x03lte\x12\xfd\x06\n" + + "\x02gt\x18\x04 \x01(\x05B\xea\x06\xc2H\xe6\x06\n" + + "t\n" + + "\bint32.gt\x1ah!has(rules.lt) && !has(rules.lte) && this <= rules.gt? 'must be greater than %s'.format([rules.gt]) : ''\n" + + "\xad\x01\n" + + "\vint32.gt_lt\x1a\x9d\x01has(rules.lt) && rules.lt >= rules.gt && (this >= rules.lt || this <= rules.gt)? 'must be greater than %s and less than %s'.format([rules.gt, rules.lt]) : ''\n" + + "\xb5\x01\n" + + "\x15int32.gt_lt_exclusive\x1a\x9b\x01has(rules.lt) && rules.lt < rules.gt && (rules.lt <= this && this <= rules.gt)? 'must be greater than %s or less than %s'.format([rules.gt, rules.lt]) : ''\n" + + "\xbd\x01\n" + + "\fint32.gt_lte\x1a\xac\x01has(rules.lte) && rules.lte >= rules.gt && (this > rules.lte || this <= rules.gt)? 
'must be greater than %s and less than or equal to %s'.format([rules.gt, rules.lte]) : ''\n" + + "\xc5\x01\n" + + "\x16int32.gt_lte_exclusive\x1a\xaa\x01has(rules.lte) && rules.lte < rules.gt && (rules.lte < this && this <= rules.gt)? 'must be greater than %s or less than or equal to %s'.format([rules.gt, rules.lte]) : ''H\x01R\x02gt\x12\xca\a\n" + + "\x03gte\x18\x05 \x01(\x05B\xb5\a\xc2H\xb1\a\n" + + "\x82\x01\n" + + "\tint32.gte\x1au!has(rules.lt) && !has(rules.lte) && this < rules.gte? 'must be greater than or equal to %s'.format([rules.gte]) : ''\n" + + "\xbc\x01\n" + + "\fint32.gte_lt\x1a\xab\x01has(rules.lt) && rules.lt >= rules.gte && (this >= rules.lt || this < rules.gte)? 'must be greater than or equal to %s and less than %s'.format([rules.gte, rules.lt]) : ''\n" + + "\xc4\x01\n" + + "\x16int32.gte_lt_exclusive\x1a\xa9\x01has(rules.lt) && rules.lt < rules.gte && (rules.lt <= this && this < rules.gte)? 'must be greater than or equal to %s or less than %s'.format([rules.gte, rules.lt]) : ''\n" + + "\xcc\x01\n" + + "\rint32.gte_lte\x1a\xba\x01has(rules.lte) && rules.lte >= rules.gte && (this > rules.lte || this < rules.gte)? 'must be greater than or equal to %s and less than or equal to %s'.format([rules.gte, rules.lte]) : ''\n" + + "\xd4\x01\n" + + "\x17int32.gte_lte_exclusive\x1a\xb8\x01has(rules.lte) && rules.lte < rules.gte && (rules.lte < this && this < rules.gte)? 'must be greater than or equal to %s or less than or equal to %s'.format([rules.gte, rules.lte]) : ''H\x01R\x03gte\x12}\n" + + "\x02in\x18\x06 \x03(\x05Bm\xc2Hj\n" + + "h\n" + + "\bint32.in\x1a\\!(this in getField(rules, 'in')) ? 'must be in list %s'.format([getField(rules, 'in')]) : ''R\x02in\x12w\n" + + "\x06not_in\x18\a \x03(\x05B`\xc2H]\n" + + "[\n" + + "\fint32.not_in\x1aKthis in rules.not_in ? 
'must not be in list %s'.format([rules.not_in]) : ''R\x05notIn\x124\n" + + "\aexample\x18\b \x03(\x05B\x1a\xc2H\x17\n" + + "\x15\n" + + "\rint32.example\x1a\x04trueR\aexample*\t\b\xe8\a\x10\x80\x80\x80\x80\x02B\v\n" + + "\tless_thanB\x0e\n" + + "\fgreater_than\"\xde\x14\n" + + "\n" + + "Int64Rules\x12\x84\x01\n" + + "\x05const\x18\x01 \x01(\x03Bn\xc2Hk\n" + + "i\n" + + "\vint64.const\x1aZthis != getField(rules, 'const') ? 'must equal %s'.format([getField(rules, 'const')]) : ''R\x05const\x12\x88\x01\n" + + "\x02lt\x18\x02 \x01(\x03Bv\xc2Hs\n" + + "q\n" + + "\bint64.lt\x1ae!has(rules.gte) && !has(rules.gt) && this >= rules.lt? 'must be less than %s'.format([rules.lt]) : ''H\x00R\x02lt\x12\x9a\x01\n" + + "\x03lte\x18\x03 \x01(\x03B\x85\x01\xc2H\x81\x01\n" + + "\x7f\n" + + "\tint64.lte\x1ar!has(rules.gte) && !has(rules.gt) && this > rules.lte? 'must be less than or equal to %s'.format([rules.lte]) : ''H\x00R\x03lte\x12\xfd\x06\n" + + "\x02gt\x18\x04 \x01(\x03B\xea\x06\xc2H\xe6\x06\n" + + "t\n" + + "\bint64.gt\x1ah!has(rules.lt) && !has(rules.lte) && this <= rules.gt? 'must be greater than %s'.format([rules.gt]) : ''\n" + + "\xad\x01\n" + + "\vint64.gt_lt\x1a\x9d\x01has(rules.lt) && rules.lt >= rules.gt && (this >= rules.lt || this <= rules.gt)? 'must be greater than %s and less than %s'.format([rules.gt, rules.lt]) : ''\n" + + "\xb5\x01\n" + + "\x15int64.gt_lt_exclusive\x1a\x9b\x01has(rules.lt) && rules.lt < rules.gt && (rules.lt <= this && this <= rules.gt)? 'must be greater than %s or less than %s'.format([rules.gt, rules.lt]) : ''\n" + + "\xbd\x01\n" + + "\fint64.gt_lte\x1a\xac\x01has(rules.lte) && rules.lte >= rules.gt && (this > rules.lte || this <= rules.gt)? 'must be greater than %s and less than or equal to %s'.format([rules.gt, rules.lte]) : ''\n" + + "\xc5\x01\n" + + "\x16int64.gt_lte_exclusive\x1a\xaa\x01has(rules.lte) && rules.lte < rules.gt && (rules.lte < this && this <= rules.gt)? 
'must be greater than %s or less than or equal to %s'.format([rules.gt, rules.lte]) : ''H\x01R\x02gt\x12\xca\a\n" + + "\x03gte\x18\x05 \x01(\x03B\xb5\a\xc2H\xb1\a\n" + + "\x82\x01\n" + + "\tint64.gte\x1au!has(rules.lt) && !has(rules.lte) && this < rules.gte? 'must be greater than or equal to %s'.format([rules.gte]) : ''\n" + + "\xbc\x01\n" + + "\fint64.gte_lt\x1a\xab\x01has(rules.lt) && rules.lt >= rules.gte && (this >= rules.lt || this < rules.gte)? 'must be greater than or equal to %s and less than %s'.format([rules.gte, rules.lt]) : ''\n" + + "\xc4\x01\n" + + "\x16int64.gte_lt_exclusive\x1a\xa9\x01has(rules.lt) && rules.lt < rules.gte && (rules.lt <= this && this < rules.gte)? 'must be greater than or equal to %s or less than %s'.format([rules.gte, rules.lt]) : ''\n" + + "\xcc\x01\n" + + "\rint64.gte_lte\x1a\xba\x01has(rules.lte) && rules.lte >= rules.gte && (this > rules.lte || this < rules.gte)? 'must be greater than or equal to %s and less than or equal to %s'.format([rules.gte, rules.lte]) : ''\n" + + "\xd4\x01\n" + + "\x17int64.gte_lte_exclusive\x1a\xb8\x01has(rules.lte) && rules.lte < rules.gte && (rules.lte < this && this < rules.gte)? 'must be greater than or equal to %s or less than or equal to %s'.format([rules.gte, rules.lte]) : ''H\x01R\x03gte\x12}\n" + + "\x02in\x18\x06 \x03(\x03Bm\xc2Hj\n" + + "h\n" + + "\bint64.in\x1a\\!(this in getField(rules, 'in')) ? 'must be in list %s'.format([getField(rules, 'in')]) : ''R\x02in\x12w\n" + + "\x06not_in\x18\a \x03(\x03B`\xc2H]\n" + + "[\n" + + "\fint64.not_in\x1aKthis in rules.not_in ? 'must not be in list %s'.format([rules.not_in]) : ''R\x05notIn\x124\n" + + "\aexample\x18\t \x03(\x03B\x1a\xc2H\x17\n" + + "\x15\n" + + "\rint64.example\x1a\x04trueR\aexample*\t\b\xe8\a\x10\x80\x80\x80\x80\x02B\v\n" + + "\tless_thanB\x0e\n" + + "\fgreater_than\"\xf0\x14\n" + + "\vUInt32Rules\x12\x85\x01\n" + + "\x05const\x18\x01 \x01(\rBo\xc2Hl\n" + + "j\n" + + "\fuint32.const\x1aZthis != getField(rules, 'const') ? 
'must equal %s'.format([getField(rules, 'const')]) : ''R\x05const\x12\x89\x01\n" + + "\x02lt\x18\x02 \x01(\rBw\xc2Ht\n" + + "r\n" + + "\tuint32.lt\x1ae!has(rules.gte) && !has(rules.gt) && this >= rules.lt? 'must be less than %s'.format([rules.lt]) : ''H\x00R\x02lt\x12\x9c\x01\n" + + "\x03lte\x18\x03 \x01(\rB\x87\x01\xc2H\x83\x01\n" + + "\x80\x01\n" + + "\n" + + "uint32.lte\x1ar!has(rules.gte) && !has(rules.gt) && this > rules.lte? 'must be less than or equal to %s'.format([rules.lte]) : ''H\x00R\x03lte\x12\x82\a\n" + + "\x02gt\x18\x04 \x01(\rB\xef\x06\xc2H\xeb\x06\n" + + "u\n" + + "\tuint32.gt\x1ah!has(rules.lt) && !has(rules.lte) && this <= rules.gt? 'must be greater than %s'.format([rules.gt]) : ''\n" + + "\xae\x01\n" + + "\fuint32.gt_lt\x1a\x9d\x01has(rules.lt) && rules.lt >= rules.gt && (this >= rules.lt || this <= rules.gt)? 'must be greater than %s and less than %s'.format([rules.gt, rules.lt]) : ''\n" + + "\xb6\x01\n" + + "\x16uint32.gt_lt_exclusive\x1a\x9b\x01has(rules.lt) && rules.lt < rules.gt && (rules.lt <= this && this <= rules.gt)? 'must be greater than %s or less than %s'.format([rules.gt, rules.lt]) : ''\n" + + "\xbe\x01\n" + + "\ruint32.gt_lte\x1a\xac\x01has(rules.lte) && rules.lte >= rules.gt && (this > rules.lte || this <= rules.gt)? 'must be greater than %s and less than or equal to %s'.format([rules.gt, rules.lte]) : ''\n" + + "\xc6\x01\n" + + "\x17uint32.gt_lte_exclusive\x1a\xaa\x01has(rules.lte) && rules.lte < rules.gt && (rules.lte < this && this <= rules.gt)? 'must be greater than %s or less than or equal to %s'.format([rules.gt, rules.lte]) : ''H\x01R\x02gt\x12\xcf\a\n" + + "\x03gte\x18\x05 \x01(\rB\xba\a\xc2H\xb6\a\n" + + "\x83\x01\n" + + "\n" + + "uint32.gte\x1au!has(rules.lt) && !has(rules.lte) && this < rules.gte? 'must be greater than or equal to %s'.format([rules.gte]) : ''\n" + + "\xbd\x01\n" + + "\ruint32.gte_lt\x1a\xab\x01has(rules.lt) && rules.lt >= rules.gte && (this >= rules.lt || this < rules.gte)? 
'must be greater than or equal to %s and less than %s'.format([rules.gte, rules.lt]) : ''\n" + + "\xc5\x01\n" + + "\x17uint32.gte_lt_exclusive\x1a\xa9\x01has(rules.lt) && rules.lt < rules.gte && (rules.lt <= this && this < rules.gte)? 'must be greater than or equal to %s or less than %s'.format([rules.gte, rules.lt]) : ''\n" + + "\xcd\x01\n" + + "\x0euint32.gte_lte\x1a\xba\x01has(rules.lte) && rules.lte >= rules.gte && (this > rules.lte || this < rules.gte)? 'must be greater than or equal to %s and less than or equal to %s'.format([rules.gte, rules.lte]) : ''\n" + + "\xd5\x01\n" + + "\x18uint32.gte_lte_exclusive\x1a\xb8\x01has(rules.lte) && rules.lte < rules.gte && (rules.lte < this && this < rules.gte)? 'must be greater than or equal to %s or less than or equal to %s'.format([rules.gte, rules.lte]) : ''H\x01R\x03gte\x12~\n" + + "\x02in\x18\x06 \x03(\rBn\xc2Hk\n" + + "i\n" + + "\tuint32.in\x1a\\!(this in getField(rules, 'in')) ? 'must be in list %s'.format([getField(rules, 'in')]) : ''R\x02in\x12x\n" + + "\x06not_in\x18\a \x03(\rBa\xc2H^\n" + + "\\\n" + + "\ruint32.not_in\x1aKthis in rules.not_in ? 'must not be in list %s'.format([rules.not_in]) : ''R\x05notIn\x125\n" + + "\aexample\x18\b \x03(\rB\x1b\xc2H\x18\n" + + "\x16\n" + + "\x0euint32.example\x1a\x04trueR\aexample*\t\b\xe8\a\x10\x80\x80\x80\x80\x02B\v\n" + + "\tless_thanB\x0e\n" + + "\fgreater_than\"\xf0\x14\n" + + "\vUInt64Rules\x12\x85\x01\n" + + "\x05const\x18\x01 \x01(\x04Bo\xc2Hl\n" + + "j\n" + + "\fuint64.const\x1aZthis != getField(rules, 'const') ? 'must equal %s'.format([getField(rules, 'const')]) : ''R\x05const\x12\x89\x01\n" + + "\x02lt\x18\x02 \x01(\x04Bw\xc2Ht\n" + + "r\n" + + "\tuint64.lt\x1ae!has(rules.gte) && !has(rules.gt) && this >= rules.lt? 'must be less than %s'.format([rules.lt]) : ''H\x00R\x02lt\x12\x9c\x01\n" + + "\x03lte\x18\x03 \x01(\x04B\x87\x01\xc2H\x83\x01\n" + + "\x80\x01\n" + + "\n" + + "uint64.lte\x1ar!has(rules.gte) && !has(rules.gt) && this > rules.lte? 
'must be less than or equal to %s'.format([rules.lte]) : ''H\x00R\x03lte\x12\x82\a\n" + + "\x02gt\x18\x04 \x01(\x04B\xef\x06\xc2H\xeb\x06\n" + + "u\n" + + "\tuint64.gt\x1ah!has(rules.lt) && !has(rules.lte) && this <= rules.gt? 'must be greater than %s'.format([rules.gt]) : ''\n" + + "\xae\x01\n" + + "\fuint64.gt_lt\x1a\x9d\x01has(rules.lt) && rules.lt >= rules.gt && (this >= rules.lt || this <= rules.gt)? 'must be greater than %s and less than %s'.format([rules.gt, rules.lt]) : ''\n" + + "\xb6\x01\n" + + "\x16uint64.gt_lt_exclusive\x1a\x9b\x01has(rules.lt) && rules.lt < rules.gt && (rules.lt <= this && this <= rules.gt)? 'must be greater than %s or less than %s'.format([rules.gt, rules.lt]) : ''\n" + + "\xbe\x01\n" + + "\ruint64.gt_lte\x1a\xac\x01has(rules.lte) && rules.lte >= rules.gt && (this > rules.lte || this <= rules.gt)? 'must be greater than %s and less than or equal to %s'.format([rules.gt, rules.lte]) : ''\n" + + "\xc6\x01\n" + + "\x17uint64.gt_lte_exclusive\x1a\xaa\x01has(rules.lte) && rules.lte < rules.gt && (rules.lte < this && this <= rules.gt)? 'must be greater than %s or less than or equal to %s'.format([rules.gt, rules.lte]) : ''H\x01R\x02gt\x12\xcf\a\n" + + "\x03gte\x18\x05 \x01(\x04B\xba\a\xc2H\xb6\a\n" + + "\x83\x01\n" + + "\n" + + "uint64.gte\x1au!has(rules.lt) && !has(rules.lte) && this < rules.gte? 'must be greater than or equal to %s'.format([rules.gte]) : ''\n" + + "\xbd\x01\n" + + "\ruint64.gte_lt\x1a\xab\x01has(rules.lt) && rules.lt >= rules.gte && (this >= rules.lt || this < rules.gte)? 'must be greater than or equal to %s and less than %s'.format([rules.gte, rules.lt]) : ''\n" + + "\xc5\x01\n" + + "\x17uint64.gte_lt_exclusive\x1a\xa9\x01has(rules.lt) && rules.lt < rules.gte && (rules.lt <= this && this < rules.gte)? 
'must be greater than or equal to %s or less than %s'.format([rules.gte, rules.lt]) : ''\n" + + "\xcd\x01\n" + + "\x0euint64.gte_lte\x1a\xba\x01has(rules.lte) && rules.lte >= rules.gte && (this > rules.lte || this < rules.gte)? 'must be greater than or equal to %s and less than or equal to %s'.format([rules.gte, rules.lte]) : ''\n" + + "\xd5\x01\n" + + "\x18uint64.gte_lte_exclusive\x1a\xb8\x01has(rules.lte) && rules.lte < rules.gte && (rules.lte < this && this < rules.gte)? 'must be greater than or equal to %s or less than or equal to %s'.format([rules.gte, rules.lte]) : ''H\x01R\x03gte\x12~\n" + + "\x02in\x18\x06 \x03(\x04Bn\xc2Hk\n" + + "i\n" + + "\tuint64.in\x1a\\!(this in getField(rules, 'in')) ? 'must be in list %s'.format([getField(rules, 'in')]) : ''R\x02in\x12x\n" + + "\x06not_in\x18\a \x03(\x04Ba\xc2H^\n" + + "\\\n" + + "\ruint64.not_in\x1aKthis in rules.not_in ? 'must not be in list %s'.format([rules.not_in]) : ''R\x05notIn\x125\n" + + "\aexample\x18\b \x03(\x04B\x1b\xc2H\x18\n" + + "\x16\n" + + "\x0euint64.example\x1a\x04trueR\aexample*\t\b\xe8\a\x10\x80\x80\x80\x80\x02B\v\n" + + "\tless_thanB\x0e\n" + + "\fgreater_than\"\xf0\x14\n" + + "\vSInt32Rules\x12\x85\x01\n" + + "\x05const\x18\x01 \x01(\x11Bo\xc2Hl\n" + + "j\n" + + "\fsint32.const\x1aZthis != getField(rules, 'const') ? 'must equal %s'.format([getField(rules, 'const')]) : ''R\x05const\x12\x89\x01\n" + + "\x02lt\x18\x02 \x01(\x11Bw\xc2Ht\n" + + "r\n" + + "\tsint32.lt\x1ae!has(rules.gte) && !has(rules.gt) && this >= rules.lt? 'must be less than %s'.format([rules.lt]) : ''H\x00R\x02lt\x12\x9c\x01\n" + + "\x03lte\x18\x03 \x01(\x11B\x87\x01\xc2H\x83\x01\n" + + "\x80\x01\n" + + "\n" + + "sint32.lte\x1ar!has(rules.gte) && !has(rules.gt) && this > rules.lte? 'must be less than or equal to %s'.format([rules.lte]) : ''H\x00R\x03lte\x12\x82\a\n" + + "\x02gt\x18\x04 \x01(\x11B\xef\x06\xc2H\xeb\x06\n" + + "u\n" + + "\tsint32.gt\x1ah!has(rules.lt) && !has(rules.lte) && this <= rules.gt? 
'must be greater than %s'.format([rules.gt]) : ''\n" + + "\xae\x01\n" + + "\fsint32.gt_lt\x1a\x9d\x01has(rules.lt) && rules.lt >= rules.gt && (this >= rules.lt || this <= rules.gt)? 'must be greater than %s and less than %s'.format([rules.gt, rules.lt]) : ''\n" + + "\xb6\x01\n" + + "\x16sint32.gt_lt_exclusive\x1a\x9b\x01has(rules.lt) && rules.lt < rules.gt && (rules.lt <= this && this <= rules.gt)? 'must be greater than %s or less than %s'.format([rules.gt, rules.lt]) : ''\n" + + "\xbe\x01\n" + + "\rsint32.gt_lte\x1a\xac\x01has(rules.lte) && rules.lte >= rules.gt && (this > rules.lte || this <= rules.gt)? 'must be greater than %s and less than or equal to %s'.format([rules.gt, rules.lte]) : ''\n" + + "\xc6\x01\n" + + "\x17sint32.gt_lte_exclusive\x1a\xaa\x01has(rules.lte) && rules.lte < rules.gt && (rules.lte < this && this <= rules.gt)? 'must be greater than %s or less than or equal to %s'.format([rules.gt, rules.lte]) : ''H\x01R\x02gt\x12\xcf\a\n" + + "\x03gte\x18\x05 \x01(\x11B\xba\a\xc2H\xb6\a\n" + + "\x83\x01\n" + + "\n" + + "sint32.gte\x1au!has(rules.lt) && !has(rules.lte) && this < rules.gte? 'must be greater than or equal to %s'.format([rules.gte]) : ''\n" + + "\xbd\x01\n" + + "\rsint32.gte_lt\x1a\xab\x01has(rules.lt) && rules.lt >= rules.gte && (this >= rules.lt || this < rules.gte)? 'must be greater than or equal to %s and less than %s'.format([rules.gte, rules.lt]) : ''\n" + + "\xc5\x01\n" + + "\x17sint32.gte_lt_exclusive\x1a\xa9\x01has(rules.lt) && rules.lt < rules.gte && (rules.lt <= this && this < rules.gte)? 'must be greater than or equal to %s or less than %s'.format([rules.gte, rules.lt]) : ''\n" + + "\xcd\x01\n" + + "\x0esint32.gte_lte\x1a\xba\x01has(rules.lte) && rules.lte >= rules.gte && (this > rules.lte || this < rules.gte)? 
'must be greater than or equal to %s and less than or equal to %s'.format([rules.gte, rules.lte]) : ''\n" + + "\xd5\x01\n" + + "\x18sint32.gte_lte_exclusive\x1a\xb8\x01has(rules.lte) && rules.lte < rules.gte && (rules.lte < this && this < rules.gte)? 'must be greater than or equal to %s or less than or equal to %s'.format([rules.gte, rules.lte]) : ''H\x01R\x03gte\x12~\n" + + "\x02in\x18\x06 \x03(\x11Bn\xc2Hk\n" + + "i\n" + + "\tsint32.in\x1a\\!(this in getField(rules, 'in')) ? 'must be in list %s'.format([getField(rules, 'in')]) : ''R\x02in\x12x\n" + + "\x06not_in\x18\a \x03(\x11Ba\xc2H^\n" + + "\\\n" + + "\rsint32.not_in\x1aKthis in rules.not_in ? 'must not be in list %s'.format([rules.not_in]) : ''R\x05notIn\x125\n" + + "\aexample\x18\b \x03(\x11B\x1b\xc2H\x18\n" + + "\x16\n" + + "\x0esint32.example\x1a\x04trueR\aexample*\t\b\xe8\a\x10\x80\x80\x80\x80\x02B\v\n" + + "\tless_thanB\x0e\n" + + "\fgreater_than\"\xf0\x14\n" + + "\vSInt64Rules\x12\x85\x01\n" + + "\x05const\x18\x01 \x01(\x12Bo\xc2Hl\n" + + "j\n" + + "\fsint64.const\x1aZthis != getField(rules, 'const') ? 'must equal %s'.format([getField(rules, 'const')]) : ''R\x05const\x12\x89\x01\n" + + "\x02lt\x18\x02 \x01(\x12Bw\xc2Ht\n" + + "r\n" + + "\tsint64.lt\x1ae!has(rules.gte) && !has(rules.gt) && this >= rules.lt? 'must be less than %s'.format([rules.lt]) : ''H\x00R\x02lt\x12\x9c\x01\n" + + "\x03lte\x18\x03 \x01(\x12B\x87\x01\xc2H\x83\x01\n" + + "\x80\x01\n" + + "\n" + + "sint64.lte\x1ar!has(rules.gte) && !has(rules.gt) && this > rules.lte? 'must be less than or equal to %s'.format([rules.lte]) : ''H\x00R\x03lte\x12\x82\a\n" + + "\x02gt\x18\x04 \x01(\x12B\xef\x06\xc2H\xeb\x06\n" + + "u\n" + + "\tsint64.gt\x1ah!has(rules.lt) && !has(rules.lte) && this <= rules.gt? 'must be greater than %s'.format([rules.gt]) : ''\n" + + "\xae\x01\n" + + "\fsint64.gt_lt\x1a\x9d\x01has(rules.lt) && rules.lt >= rules.gt && (this >= rules.lt || this <= rules.gt)? 
'must be greater than %s and less than %s'.format([rules.gt, rules.lt]) : ''\n" + + "\xb6\x01\n" + + "\x16sint64.gt_lt_exclusive\x1a\x9b\x01has(rules.lt) && rules.lt < rules.gt && (rules.lt <= this && this <= rules.gt)? 'must be greater than %s or less than %s'.format([rules.gt, rules.lt]) : ''\n" + + "\xbe\x01\n" + + "\rsint64.gt_lte\x1a\xac\x01has(rules.lte) && rules.lte >= rules.gt && (this > rules.lte || this <= rules.gt)? 'must be greater than %s and less than or equal to %s'.format([rules.gt, rules.lte]) : ''\n" + + "\xc6\x01\n" + + "\x17sint64.gt_lte_exclusive\x1a\xaa\x01has(rules.lte) && rules.lte < rules.gt && (rules.lte < this && this <= rules.gt)? 'must be greater than %s or less than or equal to %s'.format([rules.gt, rules.lte]) : ''H\x01R\x02gt\x12\xcf\a\n" + + "\x03gte\x18\x05 \x01(\x12B\xba\a\xc2H\xb6\a\n" + + "\x83\x01\n" + + "\n" + + "sint64.gte\x1au!has(rules.lt) && !has(rules.lte) && this < rules.gte? 'must be greater than or equal to %s'.format([rules.gte]) : ''\n" + + "\xbd\x01\n" + + "\rsint64.gte_lt\x1a\xab\x01has(rules.lt) && rules.lt >= rules.gte && (this >= rules.lt || this < rules.gte)? 'must be greater than or equal to %s and less than %s'.format([rules.gte, rules.lt]) : ''\n" + + "\xc5\x01\n" + + "\x17sint64.gte_lt_exclusive\x1a\xa9\x01has(rules.lt) && rules.lt < rules.gte && (rules.lt <= this && this < rules.gte)? 'must be greater than or equal to %s or less than %s'.format([rules.gte, rules.lt]) : ''\n" + + "\xcd\x01\n" + + "\x0esint64.gte_lte\x1a\xba\x01has(rules.lte) && rules.lte >= rules.gte && (this > rules.lte || this < rules.gte)? 'must be greater than or equal to %s and less than or equal to %s'.format([rules.gte, rules.lte]) : ''\n" + + "\xd5\x01\n" + + "\x18sint64.gte_lte_exclusive\x1a\xb8\x01has(rules.lte) && rules.lte < rules.gte && (rules.lte < this && this < rules.gte)? 
'must be greater than or equal to %s or less than or equal to %s'.format([rules.gte, rules.lte]) : ''H\x01R\x03gte\x12~\n" + + "\x02in\x18\x06 \x03(\x12Bn\xc2Hk\n" + + "i\n" + + "\tsint64.in\x1a\\!(this in getField(rules, 'in')) ? 'must be in list %s'.format([getField(rules, 'in')]) : ''R\x02in\x12x\n" + + "\x06not_in\x18\a \x03(\x12Ba\xc2H^\n" + + "\\\n" + + "\rsint64.not_in\x1aKthis in rules.not_in ? 'must not be in list %s'.format([rules.not_in]) : ''R\x05notIn\x125\n" + + "\aexample\x18\b \x03(\x12B\x1b\xc2H\x18\n" + + "\x16\n" + + "\x0esint64.example\x1a\x04trueR\aexample*\t\b\xe8\a\x10\x80\x80\x80\x80\x02B\v\n" + + "\tless_thanB\x0e\n" + + "\fgreater_than\"\x81\x15\n" + + "\fFixed32Rules\x12\x86\x01\n" + + "\x05const\x18\x01 \x01(\aBp\xc2Hm\n" + + "k\n" + + "\rfixed32.const\x1aZthis != getField(rules, 'const') ? 'must equal %s'.format([getField(rules, 'const')]) : ''R\x05const\x12\x8a\x01\n" + + "\x02lt\x18\x02 \x01(\aBx\xc2Hu\n" + + "s\n" + + "\n" + + "fixed32.lt\x1ae!has(rules.gte) && !has(rules.gt) && this >= rules.lt? 'must be less than %s'.format([rules.lt]) : ''H\x00R\x02lt\x12\x9d\x01\n" + + "\x03lte\x18\x03 \x01(\aB\x88\x01\xc2H\x84\x01\n" + + "\x81\x01\n" + + "\vfixed32.lte\x1ar!has(rules.gte) && !has(rules.gt) && this > rules.lte? 'must be less than or equal to %s'.format([rules.lte]) : ''H\x00R\x03lte\x12\x87\a\n" + + "\x02gt\x18\x04 \x01(\aB\xf4\x06\xc2H\xf0\x06\n" + + "v\n" + + "\n" + + "fixed32.gt\x1ah!has(rules.lt) && !has(rules.lte) && this <= rules.gt? 'must be greater than %s'.format([rules.gt]) : ''\n" + + "\xaf\x01\n" + + "\rfixed32.gt_lt\x1a\x9d\x01has(rules.lt) && rules.lt >= rules.gt && (this >= rules.lt || this <= rules.gt)? 'must be greater than %s and less than %s'.format([rules.gt, rules.lt]) : ''\n" + + "\xb7\x01\n" + + "\x17fixed32.gt_lt_exclusive\x1a\x9b\x01has(rules.lt) && rules.lt < rules.gt && (rules.lt <= this && this <= rules.gt)? 
'must be greater than %s or less than %s'.format([rules.gt, rules.lt]) : ''\n" + + "\xbf\x01\n" + + "\x0efixed32.gt_lte\x1a\xac\x01has(rules.lte) && rules.lte >= rules.gt && (this > rules.lte || this <= rules.gt)? 'must be greater than %s and less than or equal to %s'.format([rules.gt, rules.lte]) : ''\n" + + "\xc7\x01\n" + + "\x18fixed32.gt_lte_exclusive\x1a\xaa\x01has(rules.lte) && rules.lte < rules.gt && (rules.lte < this && this <= rules.gt)? 'must be greater than %s or less than or equal to %s'.format([rules.gt, rules.lte]) : ''H\x01R\x02gt\x12\xd4\a\n" + + "\x03gte\x18\x05 \x01(\aB\xbf\a\xc2H\xbb\a\n" + + "\x84\x01\n" + + "\vfixed32.gte\x1au!has(rules.lt) && !has(rules.lte) && this < rules.gte? 'must be greater than or equal to %s'.format([rules.gte]) : ''\n" + + "\xbe\x01\n" + + "\x0efixed32.gte_lt\x1a\xab\x01has(rules.lt) && rules.lt >= rules.gte && (this >= rules.lt || this < rules.gte)? 'must be greater than or equal to %s and less than %s'.format([rules.gte, rules.lt]) : ''\n" + + "\xc6\x01\n" + + "\x18fixed32.gte_lt_exclusive\x1a\xa9\x01has(rules.lt) && rules.lt < rules.gte && (rules.lt <= this && this < rules.gte)? 'must be greater than or equal to %s or less than %s'.format([rules.gte, rules.lt]) : ''\n" + + "\xce\x01\n" + + "\x0ffixed32.gte_lte\x1a\xba\x01has(rules.lte) && rules.lte >= rules.gte && (this > rules.lte || this < rules.gte)? 'must be greater than or equal to %s and less than or equal to %s'.format([rules.gte, rules.lte]) : ''\n" + + "\xd6\x01\n" + + "\x19fixed32.gte_lte_exclusive\x1a\xb8\x01has(rules.lte) && rules.lte < rules.gte && (rules.lte < this && this < rules.gte)? 'must be greater than or equal to %s or less than or equal to %s'.format([rules.gte, rules.lte]) : ''H\x01R\x03gte\x12\x7f\n" + + "\x02in\x18\x06 \x03(\aBo\xc2Hl\n" + + "j\n" + + "\n" + + "fixed32.in\x1a\\!(this in getField(rules, 'in')) ? 
'must be in list %s'.format([getField(rules, 'in')]) : ''R\x02in\x12y\n" + + "\x06not_in\x18\a \x03(\aBb\xc2H_\n" + + "]\n" + + "\x0efixed32.not_in\x1aKthis in rules.not_in ? 'must not be in list %s'.format([rules.not_in]) : ''R\x05notIn\x126\n" + + "\aexample\x18\b \x03(\aB\x1c\xc2H\x19\n" + + "\x17\n" + + "\x0ffixed32.example\x1a\x04trueR\aexample*\t\b\xe8\a\x10\x80\x80\x80\x80\x02B\v\n" + + "\tless_thanB\x0e\n" + + "\fgreater_than\"\x81\x15\n" + + "\fFixed64Rules\x12\x86\x01\n" + + "\x05const\x18\x01 \x01(\x06Bp\xc2Hm\n" + + "k\n" + + "\rfixed64.const\x1aZthis != getField(rules, 'const') ? 'must equal %s'.format([getField(rules, 'const')]) : ''R\x05const\x12\x8a\x01\n" + + "\x02lt\x18\x02 \x01(\x06Bx\xc2Hu\n" + + "s\n" + + "\n" + + "fixed64.lt\x1ae!has(rules.gte) && !has(rules.gt) && this >= rules.lt? 'must be less than %s'.format([rules.lt]) : ''H\x00R\x02lt\x12\x9d\x01\n" + + "\x03lte\x18\x03 \x01(\x06B\x88\x01\xc2H\x84\x01\n" + + "\x81\x01\n" + + "\vfixed64.lte\x1ar!has(rules.gte) && !has(rules.gt) && this > rules.lte? 'must be less than or equal to %s'.format([rules.lte]) : ''H\x00R\x03lte\x12\x87\a\n" + + "\x02gt\x18\x04 \x01(\x06B\xf4\x06\xc2H\xf0\x06\n" + + "v\n" + + "\n" + + "fixed64.gt\x1ah!has(rules.lt) && !has(rules.lte) && this <= rules.gt? 'must be greater than %s'.format([rules.gt]) : ''\n" + + "\xaf\x01\n" + + "\rfixed64.gt_lt\x1a\x9d\x01has(rules.lt) && rules.lt >= rules.gt && (this >= rules.lt || this <= rules.gt)? 'must be greater than %s and less than %s'.format([rules.gt, rules.lt]) : ''\n" + + "\xb7\x01\n" + + "\x17fixed64.gt_lt_exclusive\x1a\x9b\x01has(rules.lt) && rules.lt < rules.gt && (rules.lt <= this && this <= rules.gt)? 'must be greater than %s or less than %s'.format([rules.gt, rules.lt]) : ''\n" + + "\xbf\x01\n" + + "\x0efixed64.gt_lte\x1a\xac\x01has(rules.lte) && rules.lte >= rules.gt && (this > rules.lte || this <= rules.gt)? 
'must be greater than %s and less than or equal to %s'.format([rules.gt, rules.lte]) : ''\n" + + "\xc7\x01\n" + + "\x18fixed64.gt_lte_exclusive\x1a\xaa\x01has(rules.lte) && rules.lte < rules.gt && (rules.lte < this && this <= rules.gt)? 'must be greater than %s or less than or equal to %s'.format([rules.gt, rules.lte]) : ''H\x01R\x02gt\x12\xd4\a\n" + + "\x03gte\x18\x05 \x01(\x06B\xbf\a\xc2H\xbb\a\n" + + "\x84\x01\n" + + "\vfixed64.gte\x1au!has(rules.lt) && !has(rules.lte) && this < rules.gte? 'must be greater than or equal to %s'.format([rules.gte]) : ''\n" + + "\xbe\x01\n" + + "\x0efixed64.gte_lt\x1a\xab\x01has(rules.lt) && rules.lt >= rules.gte && (this >= rules.lt || this < rules.gte)? 'must be greater than or equal to %s and less than %s'.format([rules.gte, rules.lt]) : ''\n" + + "\xc6\x01\n" + + "\x18fixed64.gte_lt_exclusive\x1a\xa9\x01has(rules.lt) && rules.lt < rules.gte && (rules.lt <= this && this < rules.gte)? 'must be greater than or equal to %s or less than %s'.format([rules.gte, rules.lt]) : ''\n" + + "\xce\x01\n" + + "\x0ffixed64.gte_lte\x1a\xba\x01has(rules.lte) && rules.lte >= rules.gte && (this > rules.lte || this < rules.gte)? 'must be greater than or equal to %s and less than or equal to %s'.format([rules.gte, rules.lte]) : ''\n" + + "\xd6\x01\n" + + "\x19fixed64.gte_lte_exclusive\x1a\xb8\x01has(rules.lte) && rules.lte < rules.gte && (rules.lte < this && this < rules.gte)? 'must be greater than or equal to %s or less than or equal to %s'.format([rules.gte, rules.lte]) : ''H\x01R\x03gte\x12\x7f\n" + + "\x02in\x18\x06 \x03(\x06Bo\xc2Hl\n" + + "j\n" + + "\n" + + "fixed64.in\x1a\\!(this in getField(rules, 'in')) ? 'must be in list %s'.format([getField(rules, 'in')]) : ''R\x02in\x12y\n" + + "\x06not_in\x18\a \x03(\x06Bb\xc2H_\n" + + "]\n" + + "\x0efixed64.not_in\x1aKthis in rules.not_in ? 
'must not be in list %s'.format([rules.not_in]) : ''R\x05notIn\x126\n" + + "\aexample\x18\b \x03(\x06B\x1c\xc2H\x19\n" + + "\x17\n" + + "\x0ffixed64.example\x1a\x04trueR\aexample*\t\b\xe8\a\x10\x80\x80\x80\x80\x02B\v\n" + + "\tless_thanB\x0e\n" + + "\fgreater_than\"\x93\x15\n" + + "\rSFixed32Rules\x12\x87\x01\n" + + "\x05const\x18\x01 \x01(\x0fBq\xc2Hn\n" + + "l\n" + + "\x0esfixed32.const\x1aZthis != getField(rules, 'const') ? 'must equal %s'.format([getField(rules, 'const')]) : ''R\x05const\x12\x8b\x01\n" + + "\x02lt\x18\x02 \x01(\x0fBy\xc2Hv\n" + + "t\n" + + "\vsfixed32.lt\x1ae!has(rules.gte) && !has(rules.gt) && this >= rules.lt? 'must be less than %s'.format([rules.lt]) : ''H\x00R\x02lt\x12\x9e\x01\n" + + "\x03lte\x18\x03 \x01(\x0fB\x89\x01\xc2H\x85\x01\n" + + "\x82\x01\n" + + "\fsfixed32.lte\x1ar!has(rules.gte) && !has(rules.gt) && this > rules.lte? 'must be less than or equal to %s'.format([rules.lte]) : ''H\x00R\x03lte\x12\x8c\a\n" + + "\x02gt\x18\x04 \x01(\x0fB\xf9\x06\xc2H\xf5\x06\n" + + "w\n" + + "\vsfixed32.gt\x1ah!has(rules.lt) && !has(rules.lte) && this <= rules.gt? 'must be greater than %s'.format([rules.gt]) : ''\n" + + "\xb0\x01\n" + + "\x0esfixed32.gt_lt\x1a\x9d\x01has(rules.lt) && rules.lt >= rules.gt && (this >= rules.lt || this <= rules.gt)? 'must be greater than %s and less than %s'.format([rules.gt, rules.lt]) : ''\n" + + "\xb8\x01\n" + + "\x18sfixed32.gt_lt_exclusive\x1a\x9b\x01has(rules.lt) && rules.lt < rules.gt && (rules.lt <= this && this <= rules.gt)? 'must be greater than %s or less than %s'.format([rules.gt, rules.lt]) : ''\n" + + "\xc0\x01\n" + + "\x0fsfixed32.gt_lte\x1a\xac\x01has(rules.lte) && rules.lte >= rules.gt && (this > rules.lte || this <= rules.gt)? 'must be greater than %s and less than or equal to %s'.format([rules.gt, rules.lte]) : ''\n" + + "\xc8\x01\n" + + "\x19sfixed32.gt_lte_exclusive\x1a\xaa\x01has(rules.lte) && rules.lte < rules.gt && (rules.lte < this && this <= rules.gt)? 
'must be greater than %s or less than or equal to %s'.format([rules.gt, rules.lte]) : ''H\x01R\x02gt\x12\xd9\a\n" + + "\x03gte\x18\x05 \x01(\x0fB\xc4\a\xc2H\xc0\a\n" + + "\x85\x01\n" + + "\fsfixed32.gte\x1au!has(rules.lt) && !has(rules.lte) && this < rules.gte? 'must be greater than or equal to %s'.format([rules.gte]) : ''\n" + + "\xbf\x01\n" + + "\x0fsfixed32.gte_lt\x1a\xab\x01has(rules.lt) && rules.lt >= rules.gte && (this >= rules.lt || this < rules.gte)? 'must be greater than or equal to %s and less than %s'.format([rules.gte, rules.lt]) : ''\n" + + "\xc7\x01\n" + + "\x19sfixed32.gte_lt_exclusive\x1a\xa9\x01has(rules.lt) && rules.lt < rules.gte && (rules.lt <= this && this < rules.gte)? 'must be greater than or equal to %s or less than %s'.format([rules.gte, rules.lt]) : ''\n" + + "\xcf\x01\n" + + "\x10sfixed32.gte_lte\x1a\xba\x01has(rules.lte) && rules.lte >= rules.gte && (this > rules.lte || this < rules.gte)? 'must be greater than or equal to %s and less than or equal to %s'.format([rules.gte, rules.lte]) : ''\n" + + "\xd7\x01\n" + + "\x1asfixed32.gte_lte_exclusive\x1a\xb8\x01has(rules.lte) && rules.lte < rules.gte && (rules.lte < this && this < rules.gte)? 'must be greater than or equal to %s or less than or equal to %s'.format([rules.gte, rules.lte]) : ''H\x01R\x03gte\x12\x80\x01\n" + + "\x02in\x18\x06 \x03(\x0fBp\xc2Hm\n" + + "k\n" + + "\vsfixed32.in\x1a\\!(this in getField(rules, 'in')) ? 'must be in list %s'.format([getField(rules, 'in')]) : ''R\x02in\x12z\n" + + "\x06not_in\x18\a \x03(\x0fBc\xc2H`\n" + + "^\n" + + "\x0fsfixed32.not_in\x1aKthis in rules.not_in ? 
'must not be in list %s'.format([rules.not_in]) : ''R\x05notIn\x127\n" + + "\aexample\x18\b \x03(\x0fB\x1d\xc2H\x1a\n" + + "\x18\n" + + "\x10sfixed32.example\x1a\x04trueR\aexample*\t\b\xe8\a\x10\x80\x80\x80\x80\x02B\v\n" + + "\tless_thanB\x0e\n" + + "\fgreater_than\"\x93\x15\n" + + "\rSFixed64Rules\x12\x87\x01\n" + + "\x05const\x18\x01 \x01(\x10Bq\xc2Hn\n" + + "l\n" + + "\x0esfixed64.const\x1aZthis != getField(rules, 'const') ? 'must equal %s'.format([getField(rules, 'const')]) : ''R\x05const\x12\x8b\x01\n" + + "\x02lt\x18\x02 \x01(\x10By\xc2Hv\n" + + "t\n" + + "\vsfixed64.lt\x1ae!has(rules.gte) && !has(rules.gt) && this >= rules.lt? 'must be less than %s'.format([rules.lt]) : ''H\x00R\x02lt\x12\x9e\x01\n" + + "\x03lte\x18\x03 \x01(\x10B\x89\x01\xc2H\x85\x01\n" + + "\x82\x01\n" + + "\fsfixed64.lte\x1ar!has(rules.gte) && !has(rules.gt) && this > rules.lte? 'must be less than or equal to %s'.format([rules.lte]) : ''H\x00R\x03lte\x12\x8c\a\n" + + "\x02gt\x18\x04 \x01(\x10B\xf9\x06\xc2H\xf5\x06\n" + + "w\n" + + "\vsfixed64.gt\x1ah!has(rules.lt) && !has(rules.lte) && this <= rules.gt? 'must be greater than %s'.format([rules.gt]) : ''\n" + + "\xb0\x01\n" + + "\x0esfixed64.gt_lt\x1a\x9d\x01has(rules.lt) && rules.lt >= rules.gt && (this >= rules.lt || this <= rules.gt)? 'must be greater than %s and less than %s'.format([rules.gt, rules.lt]) : ''\n" + + "\xb8\x01\n" + + "\x18sfixed64.gt_lt_exclusive\x1a\x9b\x01has(rules.lt) && rules.lt < rules.gt && (rules.lt <= this && this <= rules.gt)? 'must be greater than %s or less than %s'.format([rules.gt, rules.lt]) : ''\n" + + "\xc0\x01\n" + + "\x0fsfixed64.gt_lte\x1a\xac\x01has(rules.lte) && rules.lte >= rules.gt && (this > rules.lte || this <= rules.gt)? 'must be greater than %s and less than or equal to %s'.format([rules.gt, rules.lte]) : ''\n" + + "\xc8\x01\n" + + "\x19sfixed64.gt_lte_exclusive\x1a\xaa\x01has(rules.lte) && rules.lte < rules.gt && (rules.lte < this && this <= rules.gt)? 
'must be greater than %s or less than or equal to %s'.format([rules.gt, rules.lte]) : ''H\x01R\x02gt\x12\xd9\a\n" + + "\x03gte\x18\x05 \x01(\x10B\xc4\a\xc2H\xc0\a\n" + + "\x85\x01\n" + + "\fsfixed64.gte\x1au!has(rules.lt) && !has(rules.lte) && this < rules.gte? 'must be greater than or equal to %s'.format([rules.gte]) : ''\n" + + "\xbf\x01\n" + + "\x0fsfixed64.gte_lt\x1a\xab\x01has(rules.lt) && rules.lt >= rules.gte && (this >= rules.lt || this < rules.gte)? 'must be greater than or equal to %s and less than %s'.format([rules.gte, rules.lt]) : ''\n" + + "\xc7\x01\n" + + "\x19sfixed64.gte_lt_exclusive\x1a\xa9\x01has(rules.lt) && rules.lt < rules.gte && (rules.lt <= this && this < rules.gte)? 'must be greater than or equal to %s or less than %s'.format([rules.gte, rules.lt]) : ''\n" + + "\xcf\x01\n" + + "\x10sfixed64.gte_lte\x1a\xba\x01has(rules.lte) && rules.lte >= rules.gte && (this > rules.lte || this < rules.gte)? 'must be greater than or equal to %s and less than or equal to %s'.format([rules.gte, rules.lte]) : ''\n" + + "\xd7\x01\n" + + "\x1asfixed64.gte_lte_exclusive\x1a\xb8\x01has(rules.lte) && rules.lte < rules.gte && (rules.lte < this && this < rules.gte)? 'must be greater than or equal to %s or less than or equal to %s'.format([rules.gte, rules.lte]) : ''H\x01R\x03gte\x12\x80\x01\n" + + "\x02in\x18\x06 \x03(\x10Bp\xc2Hm\n" + + "k\n" + + "\vsfixed64.in\x1a\\!(this in getField(rules, 'in')) ? 'must be in list %s'.format([getField(rules, 'in')]) : ''R\x02in\x12z\n" + + "\x06not_in\x18\a \x03(\x10Bc\xc2H`\n" + + "^\n" + + "\x0fsfixed64.not_in\x1aKthis in rules.not_in ? 
'must not be in list %s'.format([rules.not_in]) : ''R\x05notIn\x127\n" + + "\aexample\x18\b \x03(\x10B\x1d\xc2H\x1a\n" + + "\x18\n" + + "\x10sfixed64.example\x1a\x04trueR\aexample*\t\b\xe8\a\x10\x80\x80\x80\x80\x02B\v\n" + + "\tless_thanB\x0e\n" + + "\fgreater_than\"\xd1\x01\n" + + "\tBoolRules\x12\x83\x01\n" + + "\x05const\x18\x01 \x01(\bBm\xc2Hj\n" + + "h\n" + + "\n" + + "bool.const\x1aZthis != getField(rules, 'const') ? 'must equal %s'.format([getField(rules, 'const')]) : ''R\x05const\x123\n" + + "\aexample\x18\x02 \x03(\bB\x19\xc2H\x16\n" + + "\x14\n" + + "\fbool.example\x1a\x04trueR\aexample*\t\b\xe8\a\x10\x80\x80\x80\x80\x02\"\xcf?\n" + + "\vStringRules\x12\x87\x01\n" + + "\x05const\x18\x01 \x01(\tBq\xc2Hn\n" + + "l\n" + + "\fstring.const\x1a\\this != getField(rules, 'const') ? 'must equal `%s`'.format([getField(rules, 'const')]) : ''R\x05const\x12v\n" + + "\x03len\x18\x13 \x01(\x04Bd\xc2Ha\n" + + "_\n" + + "\n" + + "string.len\x1aQuint(this.size()) != rules.len ? 'must be %s characters'.format([rules.len]) : ''R\x03len\x12\x91\x01\n" + + "\amin_len\x18\x02 \x01(\x04Bx\xc2Hu\n" + + "s\n" + + "\x0estring.min_len\x1aauint(this.size()) < rules.min_len ? 'must be at least %s characters'.format([rules.min_len]) : ''R\x06minLen\x12\x90\x01\n" + + "\amax_len\x18\x03 \x01(\x04Bw\xc2Ht\n" + + "r\n" + + "\x0estring.max_len\x1a`uint(this.size()) > rules.max_len ? 'must be at most %s characters'.format([rules.max_len]) : ''R\x06maxLen\x12\x95\x01\n" + + "\tlen_bytes\x18\x14 \x01(\x04Bx\xc2Hu\n" + + "s\n" + + "\x10string.len_bytes\x1a_uint(bytes(this).size()) != rules.len_bytes ? 'must be %s bytes'.format([rules.len_bytes]) : ''R\blenBytes\x12\x9e\x01\n" + + "\tmin_bytes\x18\x04 \x01(\x04B\x80\x01\xc2H}\n" + + "{\n" + + "\x10string.min_bytes\x1aguint(bytes(this).size()) < rules.min_bytes ? 
'must be at least %s bytes'.format([rules.min_bytes]) : ''R\bminBytes\x12\x9c\x01\n" + + "\tmax_bytes\x18\x05 \x01(\x04B\x7f\xc2H|\n" + + "z\n" + + "\x10string.max_bytes\x1afuint(bytes(this).size()) > rules.max_bytes ? 'must be at most %s bytes'.format([rules.max_bytes]) : ''R\bmaxBytes\x12\x90\x01\n" + + "\apattern\x18\x06 \x01(\tBv\xc2Hs\n" + + "q\n" + + "\x0estring.pattern\x1a_!this.matches(rules.pattern) ? 'does not match regex pattern `%s`'.format([rules.pattern]) : ''R\apattern\x12\x86\x01\n" + + "\x06prefix\x18\a \x01(\tBn\xc2Hk\n" + + "i\n" + + "\rstring.prefix\x1aX!this.startsWith(rules.prefix) ? 'does not have prefix `%s`'.format([rules.prefix]) : ''R\x06prefix\x12\x84\x01\n" + + "\x06suffix\x18\b \x01(\tBl\xc2Hi\n" + + "g\n" + + "\rstring.suffix\x1aV!this.endsWith(rules.suffix) ? 'does not have suffix `%s`'.format([rules.suffix]) : ''R\x06suffix\x12\x94\x01\n" + + "\bcontains\x18\t \x01(\tBx\xc2Hu\n" + + "s\n" + + "\x0fstring.contains\x1a`!this.contains(rules.contains) ? 'does not contain substring `%s`'.format([rules.contains]) : ''R\bcontains\x12\x9e\x01\n" + + "\fnot_contains\x18\x17 \x01(\tB{\xc2Hx\n" + + "v\n" + + "\x13string.not_contains\x1a_this.contains(rules.not_contains) ? 'contains substring `%s`'.format([rules.not_contains]) : ''R\vnotContains\x12~\n" + + "\x02in\x18\n" + + " \x03(\tBn\xc2Hk\n" + + "i\n" + + "\tstring.in\x1a\\!(this in getField(rules, 'in')) ? 'must be in list %s'.format([getField(rules, 'in')]) : ''R\x02in\x12x\n" + + "\x06not_in\x18\v \x03(\tBa\xc2H^\n" + + "\\\n" + + "\rstring.not_in\x1aKthis in rules.not_in ? 
'must not be in list %s'.format([rules.not_in]) : ''R\x05notIn\x12\xe0\x01\n" + + "\x05email\x18\f \x01(\bB\xc7\x01\xc2H\xc3\x01\n" + + "[\n" + + "\fstring.email\x12\x1dmust be a valid email address\x1a,!rules.email || this == '' || this.isEmail()\n" + + "d\n" + + "\x12string.email_empty\x122value is empty, which is not a valid email address\x1a\x1a!rules.email || this != ''H\x00R\x05email\x12\xeb\x01\n" + + "\bhostname\x18\r \x01(\bB\xcc\x01\xc2H\xc8\x01\n" + + "_\n" + + "\x0fstring.hostname\x12\x18must be a valid hostname\x1a2!rules.hostname || this == '' || this.isHostname()\n" + + "e\n" + + "\x15string.hostname_empty\x12-value is empty, which is not a valid hostname\x1a\x1d!rules.hostname || this != ''H\x00R\bhostname\x12\xc5\x01\n" + + "\x02ip\x18\x0e \x01(\bB\xb2\x01\xc2H\xae\x01\n" + + "O\n" + + "\tstring.ip\x12\x1amust be a valid IP address\x1a&!rules.ip || this == '' || this.isIp()\n" + + "[\n" + + "\x0fstring.ip_empty\x12/value is empty, which is not a valid IP address\x1a\x17!rules.ip || this != ''H\x00R\x02ip\x12\xd6\x01\n" + + "\x04ipv4\x18\x0f \x01(\bB\xbf\x01\xc2H\xbb\x01\n" + + "V\n" + + "\vstring.ipv4\x12\x1cmust be a valid IPv4 address\x1a)!rules.ipv4 || this == '' || this.isIp(4)\n" + + "a\n" + + "\x11string.ipv4_empty\x121value is empty, which is not a valid IPv4 address\x1a\x19!rules.ipv4 || this != ''H\x00R\x04ipv4\x12\xd6\x01\n" + + "\x04ipv6\x18\x10 \x01(\bB\xbf\x01\xc2H\xbb\x01\n" + + "V\n" + + "\vstring.ipv6\x12\x1cmust be a valid IPv6 address\x1a)!rules.ipv6 || this == '' || this.isIp(6)\n" + + "a\n" + + "\x11string.ipv6_empty\x121value is empty, which is not a valid IPv6 address\x1a\x19!rules.ipv6 || this != ''H\x00R\x04ipv6\x12\xbe\x01\n" + + "\x03uri\x18\x11 \x01(\bB\xa9\x01\xc2H\xa5\x01\n" + + "K\n" + + "\n" + + "string.uri\x12\x13must be a valid URI\x1a(!rules.uri || this == '' || this.isUri()\n" + + "V\n" + + "\x10string.uri_empty\x12(value is empty, which is not a valid URI\x1a\x18!rules.uri || this != ''H\x00R\x03uri\x12r\n" + + 
"\auri_ref\x18\x12 \x01(\bBW\xc2HT\n" + + "R\n" + + "\x0estring.uri_ref\x12\x1dmust be a valid URI Reference\x1a!!rules.uri_ref || this.isUriRef()H\x00R\x06uriRef\x12\x92\x02\n" + + "\aaddress\x18\x15 \x01(\bB\xf5\x01\xc2H\xf1\x01\n" + + "{\n" + + "\x0estring.address\x12'must be a valid hostname, or ip address\x1a@!rules.address || this == '' || this.isHostname() || this.isIp()\n" + + "r\n" + + "\x14string.address_empty\x12!rules.ipv4_with_prefixlen || this == '' || this.isIpPrefix(4)\n" + + "\x92\x01\n" + + " string.ipv4_with_prefixlen_empty\x12Dvalue is empty, which is not a valid IPv4 address with prefix length\x1a(!rules.ipv4_with_prefixlen || this != ''H\x00R\x11ipv4WithPrefixlen\x12\xdc\x02\n" + + "\x13ipv6_with_prefixlen\x18\x1c \x01(\bB\xa9\x02\xc2H\xa5\x02\n" + + "\x8d\x01\n" + + "\x1astring.ipv6_with_prefixlen\x12/must be a valid IPv6 address with prefix length\x1a>!rules.ipv6_with_prefixlen || this == '' || this.isIpPrefix(6)\n" + + "\x92\x01\n" + + " string.ipv6_with_prefixlen_empty\x12Dvalue is empty, which is not a valid IPv6 address with prefix length\x1a(!rules.ipv6_with_prefixlen || this != ''H\x00R\x11ipv6WithPrefixlen\x12\xf6\x01\n" + + "\tip_prefix\x18\x1d \x01(\bB\xd6\x01\xc2H\xd2\x01\n" + + "f\n" + + "\x10string.ip_prefix\x12\x19must be a valid IP prefix\x1a7!rules.ip_prefix || this == '' || this.isIpPrefix(true)\n" + + "h\n" + + "\x16string.ip_prefix_empty\x12.value is empty, which is not a valid IP prefix\x1a\x1e!rules.ip_prefix || this != ''H\x00R\bipPrefix\x12\x89\x02\n" + + "\vipv4_prefix\x18\x1e \x01(\bB\xe5\x01\xc2H\xe1\x01\n" + + "o\n" + + "\x12string.ipv4_prefix\x12\x1bmust be a valid IPv4 prefix\x1a!rules.host_and_port || this == '' || this.isHostAndPort(true)\n" + + "y\n" + + "\x1astring.host_and_port_empty\x127value is empty, which is not a valid host and port pair\x1a\"!rules.host_and_port || this != ''H\x00R\vhostAndPort\x12\xf4\x01\n" + + "\x04ulid\x18# \x01(\bB\xdd\x01\xc2H\xd9\x01\n" + + "|\n" + + "\vstring.ulid\x12\x14must be 
a valid ULID\x1aW!rules.ulid || this == '' || this.matches('^[0-7][0-9A-HJKMNP-TV-Za-hjkmnp-tv-z]{25}$')\n" + + "Y\n" + + "\x11string.ulid_empty\x12)value is empty, which is not a valid ULID\x1a\x19!rules.ulid || this != ''H\x00R\x04ulid\x12\xe1\x02\n" + + "\fprotobuf_fqn\x18% \x01(\bB\xbb\x02\xc2H\xb7\x02\n" + + "\xaf\x01\n" + + "\x13string.protobuf_fqn\x12-must be a valid fully-qualified Protobuf name\x1ai!rules.protobuf_fqn || this == '' || this.matches('^[A-Za-z_][A-Za-z_0-9]*(\\\\.[A-Za-z_][A-Za-z_0-9]*)*$')\n" + + "\x82\x01\n" + + "\x19string.protobuf_fqn_empty\x12Bvalue is empty, which is not a valid fully-qualified Protobuf name\x1a!!rules.protobuf_fqn || this != ''H\x00R\vprotobufFqn\x12\xa1\x03\n" + + "\x10protobuf_dot_fqn\x18& \x01(\bB\xf4\x02\xc2H\xf0\x02\n" + + "\xcd\x01\n" + + "\x17string.protobuf_dot_fqn\x12@must be a valid fully-qualified Protobuf name with a leading dot\x1ap!rules.protobuf_dot_fqn || this == '' || this.matches('^\\\\.[A-Za-z_][A-Za-z_0-9]*(\\\\.[A-Za-z_][A-Za-z_0-9]*)*$')\n" + + "\x9d\x01\n" + + "\x1dstring.protobuf_dot_fqn_empty\x12Uvalue is empty, which is not a valid fully-qualified Protobuf name with a leading dot\x1a%!rules.protobuf_dot_fqn || this != ''H\x00R\x0eprotobufDotFqn\x12\xac\x05\n" + + "\x10well_known_regex\x18\x18 \x01(\x0e2\x18.buf.validate.KnownRegexB\xe5\x04\xc2H\xe1\x04\n" + + "\xea\x01\n" + + "#string.well_known_regex.header_name\x12 must be a valid HTTP header name\x1a\xa0\x01rules.well_known_regex != 1 || this == '' || this.matches(!has(rules.strict) || rules.strict ?'^:?[0-9a-zA-Z!#$%&\\'*+-.^_|~\\x60]+$' :'^[^\\u0000\\u000A\\u000D]+$')\n" + + "\x8d\x01\n" + + ")string.well_known_regex.header_name_empty\x125value is empty, which is not a valid HTTP header name\x1a)rules.well_known_regex != 1 || this != ''\n" + + "\xe1\x01\n" + + "$string.well_known_regex.header_value\x12!must be a valid HTTP header value\x1a\x95\x01rules.well_known_regex != 2 || this.matches(!has(rules.strict) || rules.strict 
?'^[^\\u0000-\\u0008\\u000A-\\u001F\\u007F]*$' :'^[^\\u0000\\u000A\\u000D]*$')H\x00R\x0ewellKnownRegex\x12\x16\n" + + "\x06strict\x18\x19 \x01(\bR\x06strict\x125\n" + + "\aexample\x18\" \x03(\tB\x1b\xc2H\x18\n" + + "\x16\n" + + "\x0estring.example\x1a\x04trueR\aexample*\t\b\xe8\a\x10\x80\x80\x80\x80\x02B\f\n" + + "\n" + + "well_known\"\xca\x12\n" + + "\n" + + "BytesRules\x12\x81\x01\n" + + "\x05const\x18\x01 \x01(\fBk\xc2Hh\n" + + "f\n" + + "\vbytes.const\x1aWthis != getField(rules, 'const') ? 'must be %x'.format([getField(rules, 'const')]) : ''R\x05const\x12p\n" + + "\x03len\x18\r \x01(\x04B^\xc2H[\n" + + "Y\n" + + "\tbytes.len\x1aLuint(this.size()) != rules.len ? 'must be %s bytes'.format([rules.len]) : ''R\x03len\x12\x8b\x01\n" + + "\amin_len\x18\x02 \x01(\x04Br\xc2Ho\n" + + "m\n" + + "\rbytes.min_len\x1a\\uint(this.size()) < rules.min_len ? 'must be at least %s bytes'.format([rules.min_len]) : ''R\x06minLen\x12\x8a\x01\n" + + "\amax_len\x18\x03 \x01(\x04Bq\xc2Hn\n" + + "l\n" + + "\rbytes.max_len\x1a[uint(this.size()) > rules.max_len ? 'must be at most %s bytes'.format([rules.max_len]) : ''R\x06maxLen\x12\x93\x01\n" + + "\apattern\x18\x04 \x01(\tBy\xc2Hv\n" + + "t\n" + + "\rbytes.pattern\x1ac!string(this).matches(rules.pattern) ? 'must match regex pattern `%s`'.format([rules.pattern]) : ''R\apattern\x12\x83\x01\n" + + "\x06prefix\x18\x05 \x01(\fBk\xc2Hh\n" + + "f\n" + + "\fbytes.prefix\x1aV!this.startsWith(rules.prefix) ? 'does not have prefix %x'.format([rules.prefix]) : ''R\x06prefix\x12\x81\x01\n" + + "\x06suffix\x18\x06 \x01(\fBi\xc2Hf\n" + + "d\n" + + "\fbytes.suffix\x1aT!this.endsWith(rules.suffix) ? 'does not have suffix %x'.format([rules.suffix]) : ''R\x06suffix\x12\x87\x01\n" + + "\bcontains\x18\a \x01(\fBk\xc2Hh\n" + + "f\n" + + "\x0ebytes.contains\x1aT!this.contains(rules.contains) ? 
'does not contain %x'.format([rules.contains]) : ''R\bcontains\x12\xa5\x01\n" + + "\x02in\x18\b \x03(\fB\x94\x01\xc2H\x90\x01\n" + + "\x8d\x01\n" + + "\bbytes.in\x1a\x80\x01getField(rules, 'in').size() > 0 && !(this in getField(rules, 'in')) ? 'must be in list %s'.format([getField(rules, 'in')]) : ''R\x02in\x12w\n" + + "\x06not_in\x18\t \x03(\fB`\xc2H]\n" + + "[\n" + + "\fbytes.not_in\x1aKthis in rules.not_in ? 'must not be in list %s'.format([rules.not_in]) : ''R\x05notIn\x12\xe9\x01\n" + + "\x02ip\x18\n" + + " \x01(\bB\xd6\x01\xc2H\xd2\x01\n" + + "n\n" + + "\bbytes.ip\x12\x1amust be a valid IP address\x1aF!rules.ip || this.size() == 0 || this.size() == 4 || this.size() == 16\n" + + "`\n" + + "\x0ebytes.ip_empty\x12/value is empty, which is not a valid IP address\x1a\x1d!rules.ip || this.size() != 0H\x00R\x02ip\x12\xe4\x01\n" + + "\x04ipv4\x18\v \x01(\bB\xcd\x01\xc2H\xc9\x01\n" + + "_\n" + + "\n" + + "bytes.ipv4\x12\x1cmust be a valid IPv4 address\x1a3!rules.ipv4 || this.size() == 0 || this.size() == 4\n" + + "f\n" + + "\x10bytes.ipv4_empty\x121value is empty, which is not a valid IPv4 address\x1a\x1f!rules.ipv4 || this.size() != 0H\x00R\x04ipv4\x12\xe5\x01\n" + + "\x04ipv6\x18\f \x01(\bB\xce\x01\xc2H\xca\x01\n" + + "`\n" + + "\n" + + "bytes.ipv6\x12\x1cmust be a valid IPv6 address\x1a4!rules.ipv6 || this.size() == 0 || this.size() == 16\n" + + "f\n" + + "\x10bytes.ipv6_empty\x121value is empty, which is not a valid IPv6 address\x1a\x1f!rules.ipv6 || this.size() != 0H\x00R\x04ipv6\x12\xd5\x01\n" + + "\x04uuid\x18\x0f \x01(\bB\xbe\x01\xc2H\xba\x01\n" + + "X\n" + + "\n" + + "bytes.uuid\x12\x14must be a valid UUID\x1a4!rules.uuid || this.size() == 0 || this.size() == 16\n" + + "^\n" + + "\x10bytes.uuid_empty\x12)value is empty, which is not a valid UUID\x1a\x1f!rules.uuid || this.size() != 0H\x00R\x04uuid\x124\n" + + "\aexample\x18\x0e \x03(\fB\x1a\xc2H\x17\n" + + "\x15\n" + + "\rbytes.example\x1a\x04trueR\aexample*\t\b\xe8\a\x10\x80\x80\x80\x80\x02B\f\n" + + "\n" + 
+ "well_known\"\xea\x03\n" + + "\tEnumRules\x12\x83\x01\n" + + "\x05const\x18\x01 \x01(\x05Bm\xc2Hj\n" + + "h\n" + + "\n" + + "enum.const\x1aZthis != getField(rules, 'const') ? 'must equal %s'.format([getField(rules, 'const')]) : ''R\x05const\x12!\n" + + "\fdefined_only\x18\x02 \x01(\bR\vdefinedOnly\x12|\n" + + "\x02in\x18\x03 \x03(\x05Bl\xc2Hi\n" + + "g\n" + + "\aenum.in\x1a\\!(this in getField(rules, 'in')) ? 'must be in list %s'.format([getField(rules, 'in')]) : ''R\x02in\x12v\n" + + "\x06not_in\x18\x04 \x03(\x05B_\xc2H\\\n" + + "Z\n" + + "\venum.not_in\x1aKthis in rules.not_in ? 'must not be in list %s'.format([rules.not_in]) : ''R\x05notIn\x123\n" + + "\aexample\x18\x05 \x03(\x05B\x19\xc2H\x16\n" + + "\x14\n" + + "\fenum.example\x1a\x04trueR\aexample*\t\b\xe8\a\x10\x80\x80\x80\x80\x02\"\x90\x04\n" + + "\rRepeatedRules\x12\xa0\x01\n" + + "\tmin_items\x18\x01 \x01(\x04B\x82\x01\xc2H\x7f\n" + + "}\n" + + "\x12repeated.min_items\x1aguint(this.size()) < rules.min_items ? 'must contain at least %d item(s)'.format([rules.min_items]) : ''R\bminItems\x12\xa6\x01\n" + + "\tmax_items\x18\x02 \x01(\x04B\x88\x01\xc2H\x84\x01\n" + + "\x81\x01\n" + + "\x12repeated.max_items\x1akuint(this.size()) > rules.max_items ? 'must contain no more than %s item(s)'.format([rules.max_items]) : ''R\bmaxItems\x12x\n" + + "\x06unique\x18\x03 \x01(\bB`\xc2H]\n" + + "[\n" + + "\x0frepeated.unique\x12(repeated value must contain unique items\x1a\x1e!rules.unique || this.unique()R\x06unique\x12.\n" + + "\x05items\x18\x04 \x01(\v2\x18.buf.validate.FieldRulesR\x05items*\t\b\xe8\a\x10\x80\x80\x80\x80\x02\"\xac\x03\n" + + "\bMapRules\x12\x99\x01\n" + + "\tmin_pairs\x18\x01 \x01(\x04B|\xc2Hy\n" + + "w\n" + + "\rmap.min_pairs\x1afuint(this.size()) < rules.min_pairs ? 'map must be at least %d entries'.format([rules.min_pairs]) : ''R\bminPairs\x12\x98\x01\n" + + "\tmax_pairs\x18\x02 \x01(\x04B{\xc2Hx\n" + + "v\n" + + "\rmap.max_pairs\x1aeuint(this.size()) > rules.max_pairs ? 
'map must be at most %d entries'.format([rules.max_pairs]) : ''R\bmaxPairs\x12,\n" + + "\x04keys\x18\x04 \x01(\v2\x18.buf.validate.FieldRulesR\x04keys\x120\n" + + "\x06values\x18\x05 \x01(\v2\x18.buf.validate.FieldRulesR\x06values*\t\b\xe8\a\x10\x80\x80\x80\x80\x02\"1\n" + + "\bAnyRules\x12\x0e\n" + + "\x02in\x18\x02 \x03(\tR\x02in\x12\x15\n" + + "\x06not_in\x18\x03 \x03(\tR\x05notIn\"\xec\x16\n" + + "\rDurationRules\x12\xa2\x01\n" + + "\x05const\x18\x02 \x01(\v2\x19.google.protobuf.DurationBq\xc2Hn\n" + + "l\n" + + "\x0eduration.const\x1aZthis != getField(rules, 'const') ? 'must equal %s'.format([getField(rules, 'const')]) : ''R\x05const\x12\xa6\x01\n" + + "\x02lt\x18\x03 \x01(\v2\x19.google.protobuf.DurationBy\xc2Hv\n" + + "t\n" + + "\vduration.lt\x1ae!has(rules.gte) && !has(rules.gt) && this >= rules.lt? 'must be less than %s'.format([rules.lt]) : ''H\x00R\x02lt\x12\xb9\x01\n" + + "\x03lte\x18\x04 \x01(\v2\x19.google.protobuf.DurationB\x89\x01\xc2H\x85\x01\n" + + "\x82\x01\n" + + "\fduration.lte\x1ar!has(rules.gte) && !has(rules.gt) && this > rules.lte? 'must be less than or equal to %s'.format([rules.lte]) : ''H\x00R\x03lte\x12\xa7\a\n" + + "\x02gt\x18\x05 \x01(\v2\x19.google.protobuf.DurationB\xf9\x06\xc2H\xf5\x06\n" + + "w\n" + + "\vduration.gt\x1ah!has(rules.lt) && !has(rules.lte) && this <= rules.gt? 'must be greater than %s'.format([rules.gt]) : ''\n" + + "\xb0\x01\n" + + "\x0eduration.gt_lt\x1a\x9d\x01has(rules.lt) && rules.lt >= rules.gt && (this >= rules.lt || this <= rules.gt)? 'must be greater than %s and less than %s'.format([rules.gt, rules.lt]) : ''\n" + + "\xb8\x01\n" + + "\x18duration.gt_lt_exclusive\x1a\x9b\x01has(rules.lt) && rules.lt < rules.gt && (rules.lt <= this && this <= rules.gt)? 'must be greater than %s or less than %s'.format([rules.gt, rules.lt]) : ''\n" + + "\xc0\x01\n" + + "\x0fduration.gt_lte\x1a\xac\x01has(rules.lte) && rules.lte >= rules.gt && (this > rules.lte || this <= rules.gt)? 
'must be greater than %s and less than or equal to %s'.format([rules.gt, rules.lte]) : ''\n" + + "\xc8\x01\n" + + "\x19duration.gt_lte_exclusive\x1a\xaa\x01has(rules.lte) && rules.lte < rules.gt && (rules.lte < this && this <= rules.gt)? 'must be greater than %s or less than or equal to %s'.format([rules.gt, rules.lte]) : ''H\x01R\x02gt\x12\xf4\a\n" + + "\x03gte\x18\x06 \x01(\v2\x19.google.protobuf.DurationB\xc4\a\xc2H\xc0\a\n" + + "\x85\x01\n" + + "\fduration.gte\x1au!has(rules.lt) && !has(rules.lte) && this < rules.gte? 'must be greater than or equal to %s'.format([rules.gte]) : ''\n" + + "\xbf\x01\n" + + "\x0fduration.gte_lt\x1a\xab\x01has(rules.lt) && rules.lt >= rules.gte && (this >= rules.lt || this < rules.gte)? 'must be greater than or equal to %s and less than %s'.format([rules.gte, rules.lt]) : ''\n" + + "\xc7\x01\n" + + "\x19duration.gte_lt_exclusive\x1a\xa9\x01has(rules.lt) && rules.lt < rules.gte && (rules.lt <= this && this < rules.gte)? 'must be greater than or equal to %s or less than %s'.format([rules.gte, rules.lt]) : ''\n" + + "\xcf\x01\n" + + "\x10duration.gte_lte\x1a\xba\x01has(rules.lte) && rules.lte >= rules.gte && (this > rules.lte || this < rules.gte)? 'must be greater than or equal to %s and less than or equal to %s'.format([rules.gte, rules.lte]) : ''\n" + + "\xd7\x01\n" + + "\x1aduration.gte_lte_exclusive\x1a\xb8\x01has(rules.lte) && rules.lte < rules.gte && (rules.lte < this && this < rules.gte)? 'must be greater than or equal to %s or less than or equal to %s'.format([rules.gte, rules.lte]) : ''H\x01R\x03gte\x12\x9b\x01\n" + + "\x02in\x18\a \x03(\v2\x19.google.protobuf.DurationBp\xc2Hm\n" + + "k\n" + + "\vduration.in\x1a\\!(this in getField(rules, 'in')) ? 'must be in list %s'.format([getField(rules, 'in')]) : ''R\x02in\x12\x95\x01\n" + + "\x06not_in\x18\b \x03(\v2\x19.google.protobuf.DurationBc\xc2H`\n" + + "^\n" + + "\x0fduration.not_in\x1aKthis in rules.not_in ? 
'must not be in list %s'.format([rules.not_in]) : ''R\x05notIn\x12R\n" + + "\aexample\x18\t \x03(\v2\x19.google.protobuf.DurationB\x1d\xc2H\x1a\n" + + "\x18\n" + + "\x10duration.example\x1a\x04trueR\aexample*\t\b\xe8\a\x10\x80\x80\x80\x80\x02B\v\n" + + "\tless_thanB\x0e\n" + + "\fgreater_than\"\x86\x06\n" + + "\x0eFieldMaskRules\x12\xc0\x01\n" + + "\x05const\x18\x01 \x01(\v2\x1a.google.protobuf.FieldMaskB\x8d\x01\xc2H\x89\x01\n" + + "\x86\x01\n" + + "\x10field_mask.const\x1arthis.paths != getField(rules, 'const').paths ? 'must equal paths %s'.format([getField(rules, 'const').paths]) : ''R\x05const\x12\xd7\x01\n" + + "\x02in\x18\x02 \x03(\tB\xc6\x01\xc2H\xc2\x01\n" + + "\xbf\x01\n" + + "\rfield_mask.in\x1a\xad\x01!this.paths.all(p, p in getField(rules, 'in') || getField(rules, 'in').exists(f, p.startsWith(f+'.'))) ? 'must only contain paths in %s'.format([getField(rules, 'in')]) : ''R\x02in\x12\xf4\x01\n" + + "\x06not_in\x18\x03 \x03(\tB\xdc\x01\xc2H\xd8\x01\n" + + "\xd5\x01\n" + + "\x11field_mask.not_in\x1a\xbf\x01!this.paths.all(p, !(p in getField(rules, 'not_in') || getField(rules, 'not_in').exists(f, p.startsWith(f+'.')))) ? 'must not contain any paths in %s'.format([getField(rules, 'not_in')]) : ''R\x05notIn\x12U\n" + + "\aexample\x18\x04 \x03(\v2\x1a.google.protobuf.FieldMaskB\x1f\xc2H\x1c\n" + + "\x1a\n" + + "\x12field_mask.example\x1a\x04trueR\aexample*\t\b\xe8\a\x10\x80\x80\x80\x80\x02\"\xe8\x17\n" + + "\x0eTimestampRules\x12\xa4\x01\n" + + "\x05const\x18\x02 \x01(\v2\x1a.google.protobuf.TimestampBr\xc2Ho\n" + + "m\n" + + "\x0ftimestamp.const\x1aZthis != getField(rules, 'const') ? 'must equal %s'.format([getField(rules, 'const')]) : ''R\x05const\x12\xa8\x01\n" + + "\x02lt\x18\x03 \x01(\v2\x1a.google.protobuf.TimestampBz\xc2Hw\n" + + "u\n" + + "\ftimestamp.lt\x1ae!has(rules.gte) && !has(rules.gt) && this >= rules.lt? 
'must be less than %s'.format([rules.lt]) : ''H\x00R\x02lt\x12\xbb\x01\n" + + "\x03lte\x18\x04 \x01(\v2\x1a.google.protobuf.TimestampB\x8a\x01\xc2H\x86\x01\n" + + "\x83\x01\n" + + "\rtimestamp.lte\x1ar!has(rules.gte) && !has(rules.gt) && this > rules.lte? 'must be less than or equal to %s'.format([rules.lte]) : ''H\x00R\x03lte\x12m\n" + + "\x06lt_now\x18\a \x01(\bBT\xc2HQ\n" + + "O\n" + + "\x10timestamp.lt_now\x1a;(rules.lt_now && this > now) ? 'must be less than now' : ''H\x00R\x05ltNow\x12\xad\a\n" + + "\x02gt\x18\x05 \x01(\v2\x1a.google.protobuf.TimestampB\xfe\x06\xc2H\xfa\x06\n" + + "x\n" + + "\ftimestamp.gt\x1ah!has(rules.lt) && !has(rules.lte) && this <= rules.gt? 'must be greater than %s'.format([rules.gt]) : ''\n" + + "\xb1\x01\n" + + "\x0ftimestamp.gt_lt\x1a\x9d\x01has(rules.lt) && rules.lt >= rules.gt && (this >= rules.lt || this <= rules.gt)? 'must be greater than %s and less than %s'.format([rules.gt, rules.lt]) : ''\n" + + "\xb9\x01\n" + + "\x19timestamp.gt_lt_exclusive\x1a\x9b\x01has(rules.lt) && rules.lt < rules.gt && (rules.lt <= this && this <= rules.gt)? 'must be greater than %s or less than %s'.format([rules.gt, rules.lt]) : ''\n" + + "\xc1\x01\n" + + "\x10timestamp.gt_lte\x1a\xac\x01has(rules.lte) && rules.lte >= rules.gt && (this > rules.lte || this <= rules.gt)? 'must be greater than %s and less than or equal to %s'.format([rules.gt, rules.lte]) : ''\n" + + "\xc9\x01\n" + + "\x1atimestamp.gt_lte_exclusive\x1a\xaa\x01has(rules.lte) && rules.lte < rules.gt && (rules.lte < this && this <= rules.gt)? 'must be greater than %s or less than or equal to %s'.format([rules.gt, rules.lte]) : ''H\x01R\x02gt\x12\xfa\a\n" + + "\x03gte\x18\x06 \x01(\v2\x1a.google.protobuf.TimestampB\xc9\a\xc2H\xc5\a\n" + + "\x86\x01\n" + + "\rtimestamp.gte\x1au!has(rules.lt) && !has(rules.lte) && this < rules.gte? 
'must be greater than or equal to %s'.format([rules.gte]) : ''\n" + + "\xc0\x01\n" + + "\x10timestamp.gte_lt\x1a\xab\x01has(rules.lt) && rules.lt >= rules.gte && (this >= rules.lt || this < rules.gte)? 'must be greater than or equal to %s and less than %s'.format([rules.gte, rules.lt]) : ''\n" + + "\xc8\x01\n" + + "\x1atimestamp.gte_lt_exclusive\x1a\xa9\x01has(rules.lt) && rules.lt < rules.gte && (rules.lt <= this && this < rules.gte)? 'must be greater than or equal to %s or less than %s'.format([rules.gte, rules.lt]) : ''\n" + + "\xd0\x01\n" + + "\x11timestamp.gte_lte\x1a\xba\x01has(rules.lte) && rules.lte >= rules.gte && (this > rules.lte || this < rules.gte)? 'must be greater than or equal to %s and less than or equal to %s'.format([rules.gte, rules.lte]) : ''\n" + + "\xd8\x01\n" + + "\x1btimestamp.gte_lte_exclusive\x1a\xb8\x01has(rules.lte) && rules.lte < rules.gte && (rules.lte < this && this < rules.gte)? 'must be greater than or equal to %s or less than or equal to %s'.format([rules.gte, rules.lte]) : ''H\x01R\x03gte\x12p\n" + + "\x06gt_now\x18\b \x01(\bBW\xc2HT\n" + + "R\n" + + "\x10timestamp.gt_now\x1a>(rules.gt_now && this < now) ? 'must be greater than now' : ''H\x01R\x05gtNow\x12\xb9\x01\n" + + "\x06within\x18\t \x01(\v2\x19.google.protobuf.DurationB\x85\x01\xc2H\x81\x01\n" + + "\x7f\n" + + "\x10timestamp.within\x1akthis < now-rules.within || this > now+rules.within ? 
'must be within %s of now'.format([rules.within]) : ''R\x06within\x12T\n" + + "\aexample\x18\n" + + " \x03(\v2\x1a.google.protobuf.TimestampB\x1e\xc2H\x1b\n" + + "\x19\n" + + "\x11timestamp.example\x1a\x04trueR\aexample*\t\b\xe8\a\x10\x80\x80\x80\x80\x02B\v\n" + + "\tless_thanB\x0e\n" + + "\fgreater_than\"E\n" + + "\n" + + "Violations\x127\n" + + "\n" + + "violations\x18\x01 \x03(\v2\x17.buf.validate.ViolationR\n" + + "violations\"\xc5\x01\n" + + "\tViolation\x12-\n" + + "\x05field\x18\x05 \x01(\v2\x17.buf.validate.FieldPathR\x05field\x12+\n" + + "\x04rule\x18\x06 \x01(\v2\x17.buf.validate.FieldPathR\x04rule\x12\x17\n" + + "\arule_id\x18\x02 \x01(\tR\x06ruleId\x12\x18\n" + + "\amessage\x18\x03 \x01(\tR\amessage\x12\x17\n" + + "\afor_key\x18\x04 \x01(\bR\x06forKeyJ\x04\b\x01\x10\x02R\n" + + "field_path\"G\n" + + "\tFieldPath\x12:\n" + + "\belements\x18\x01 \x03(\v2\x1e.buf.validate.FieldPathElementR\belements\"\xcc\x03\n" + + "\x10FieldPathElement\x12!\n" + + "\ffield_number\x18\x01 \x01(\x05R\vfieldNumber\x12\x1d\n" + + "\n" + + "field_name\x18\x02 \x01(\tR\tfieldName\x12I\n" + + "\n" + + "field_type\x18\x03 \x01(\x0e2*.google.protobuf.FieldDescriptorProto.TypeR\tfieldType\x12E\n" + + "\bkey_type\x18\x04 \x01(\x0e2*.google.protobuf.FieldDescriptorProto.TypeR\akeyType\x12I\n" + + "\n" + + "value_type\x18\x05 \x01(\x0e2*.google.protobuf.FieldDescriptorProto.TypeR\tvalueType\x12\x16\n" + + "\x05index\x18\x06 \x01(\x04H\x00R\x05index\x12\x1b\n" + + "\bbool_key\x18\a \x01(\bH\x00R\aboolKey\x12\x19\n" + + "\aint_key\x18\b \x01(\x03H\x00R\x06intKey\x12\x1b\n" + + "\buint_key\x18\t \x01(\x04H\x00R\auintKey\x12\x1f\n" + + "\n" + + "string_key\x18\n" + + " \x01(\tH\x00R\tstringKeyB\v\n" + + "\tsubscript*\xa1\x01\n" + + "\x06Ignore\x12\x16\n" + + "\x12IGNORE_UNSPECIFIED\x10\x00\x12\x18\n" + + "\x14IGNORE_IF_ZERO_VALUE\x10\x01\x12\x11\n" + + "\rIGNORE_ALWAYS\x10\x03\"\x04\b\x02\x10\x02*\fIGNORE_EMPTY*\x0eIGNORE_DEFAULT*\x17IGNORE_IF_DEFAULT_VALUE*\x15IGNORE_IF_UNPOPULATED*n\n" 
+ + "\n" + + "KnownRegex\x12\x1b\n" + + "\x17KNOWN_REGEX_UNSPECIFIED\x10\x00\x12 \n" + + "\x1cKNOWN_REGEX_HTTP_HEADER_NAME\x10\x01\x12!\n" + + "\x1dKNOWN_REGEX_HTTP_HEADER_VALUE\x10\x02:V\n" + + "\amessage\x12\x1f.google.protobuf.MessageOptions\x18\x87\t \x01(\v2\x1a.buf.validate.MessageRulesR\amessage:N\n" + + "\x05oneof\x12\x1d.google.protobuf.OneofOptions\x18\x87\t \x01(\v2\x18.buf.validate.OneofRulesR\x05oneof:N\n" + + "\x05field\x12\x1d.google.protobuf.FieldOptions\x18\x87\t \x01(\v2\x18.buf.validate.FieldRulesR\x05field:]\n" + + "\n" + + "predefined\x12\x1d.google.protobuf.FieldOptions\x18\x88\t \x01(\v2\x1d.buf.validate.PredefinedRulesR\n" + + "predefinedB\xae\x01\n" + + "\x10com.buf.validateB\rValidateProtoP\x01Z:github.com/wundergraph/cosmo/router/gen/proto/buf/validate\xa2\x02\x03BVX\xaa\x02\fBuf.Validate\xca\x02\fBuf\\Validate\xe2\x02\x18Buf\\Validate\\GPBMetadata\xea\x02\rBuf::Validate" + +var ( + file_buf_validate_validate_proto_rawDescOnce sync.Once + file_buf_validate_validate_proto_rawDescData []byte +) + +func file_buf_validate_validate_proto_rawDescGZIP() []byte { + file_buf_validate_validate_proto_rawDescOnce.Do(func() { + file_buf_validate_validate_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_buf_validate_validate_proto_rawDesc), len(file_buf_validate_validate_proto_rawDesc))) + }) + return file_buf_validate_validate_proto_rawDescData +} + +var file_buf_validate_validate_proto_enumTypes = make([]protoimpl.EnumInfo, 2) +var file_buf_validate_validate_proto_msgTypes = make([]protoimpl.MessageInfo, 32) +var file_buf_validate_validate_proto_goTypes = []any{ + (Ignore)(0), // 0: buf.validate.Ignore + (KnownRegex)(0), // 1: buf.validate.KnownRegex + (*Rule)(nil), // 2: buf.validate.Rule + (*MessageRules)(nil), // 3: buf.validate.MessageRules + (*MessageOneofRule)(nil), // 4: buf.validate.MessageOneofRule + (*OneofRules)(nil), // 5: buf.validate.OneofRules + (*FieldRules)(nil), // 6: buf.validate.FieldRules + 
(*PredefinedRules)(nil), // 7: buf.validate.PredefinedRules + (*FloatRules)(nil), // 8: buf.validate.FloatRules + (*DoubleRules)(nil), // 9: buf.validate.DoubleRules + (*Int32Rules)(nil), // 10: buf.validate.Int32Rules + (*Int64Rules)(nil), // 11: buf.validate.Int64Rules + (*UInt32Rules)(nil), // 12: buf.validate.UInt32Rules + (*UInt64Rules)(nil), // 13: buf.validate.UInt64Rules + (*SInt32Rules)(nil), // 14: buf.validate.SInt32Rules + (*SInt64Rules)(nil), // 15: buf.validate.SInt64Rules + (*Fixed32Rules)(nil), // 16: buf.validate.Fixed32Rules + (*Fixed64Rules)(nil), // 17: buf.validate.Fixed64Rules + (*SFixed32Rules)(nil), // 18: buf.validate.SFixed32Rules + (*SFixed64Rules)(nil), // 19: buf.validate.SFixed64Rules + (*BoolRules)(nil), // 20: buf.validate.BoolRules + (*StringRules)(nil), // 21: buf.validate.StringRules + (*BytesRules)(nil), // 22: buf.validate.BytesRules + (*EnumRules)(nil), // 23: buf.validate.EnumRules + (*RepeatedRules)(nil), // 24: buf.validate.RepeatedRules + (*MapRules)(nil), // 25: buf.validate.MapRules + (*AnyRules)(nil), // 26: buf.validate.AnyRules + (*DurationRules)(nil), // 27: buf.validate.DurationRules + (*FieldMaskRules)(nil), // 28: buf.validate.FieldMaskRules + (*TimestampRules)(nil), // 29: buf.validate.TimestampRules + (*Violations)(nil), // 30: buf.validate.Violations + (*Violation)(nil), // 31: buf.validate.Violation + (*FieldPath)(nil), // 32: buf.validate.FieldPath + (*FieldPathElement)(nil), // 33: buf.validate.FieldPathElement + (*durationpb.Duration)(nil), // 34: google.protobuf.Duration + (*fieldmaskpb.FieldMask)(nil), // 35: google.protobuf.FieldMask + (*timestamppb.Timestamp)(nil), // 36: google.protobuf.Timestamp + (descriptorpb.FieldDescriptorProto_Type)(0), // 37: google.protobuf.FieldDescriptorProto.Type + (*descriptorpb.MessageOptions)(nil), // 38: google.protobuf.MessageOptions + (*descriptorpb.OneofOptions)(nil), // 39: google.protobuf.OneofOptions + (*descriptorpb.FieldOptions)(nil), // 40: 
google.protobuf.FieldOptions +} +var file_buf_validate_validate_proto_depIdxs = []int32{ + 2, // 0: buf.validate.MessageRules.cel:type_name -> buf.validate.Rule + 4, // 1: buf.validate.MessageRules.oneof:type_name -> buf.validate.MessageOneofRule + 2, // 2: buf.validate.FieldRules.cel:type_name -> buf.validate.Rule + 0, // 3: buf.validate.FieldRules.ignore:type_name -> buf.validate.Ignore + 8, // 4: buf.validate.FieldRules.float:type_name -> buf.validate.FloatRules + 9, // 5: buf.validate.FieldRules.double:type_name -> buf.validate.DoubleRules + 10, // 6: buf.validate.FieldRules.int32:type_name -> buf.validate.Int32Rules + 11, // 7: buf.validate.FieldRules.int64:type_name -> buf.validate.Int64Rules + 12, // 8: buf.validate.FieldRules.uint32:type_name -> buf.validate.UInt32Rules + 13, // 9: buf.validate.FieldRules.uint64:type_name -> buf.validate.UInt64Rules + 14, // 10: buf.validate.FieldRules.sint32:type_name -> buf.validate.SInt32Rules + 15, // 11: buf.validate.FieldRules.sint64:type_name -> buf.validate.SInt64Rules + 16, // 12: buf.validate.FieldRules.fixed32:type_name -> buf.validate.Fixed32Rules + 17, // 13: buf.validate.FieldRules.fixed64:type_name -> buf.validate.Fixed64Rules + 18, // 14: buf.validate.FieldRules.sfixed32:type_name -> buf.validate.SFixed32Rules + 19, // 15: buf.validate.FieldRules.sfixed64:type_name -> buf.validate.SFixed64Rules + 20, // 16: buf.validate.FieldRules.bool:type_name -> buf.validate.BoolRules + 21, // 17: buf.validate.FieldRules.string:type_name -> buf.validate.StringRules + 22, // 18: buf.validate.FieldRules.bytes:type_name -> buf.validate.BytesRules + 23, // 19: buf.validate.FieldRules.enum:type_name -> buf.validate.EnumRules + 24, // 20: buf.validate.FieldRules.repeated:type_name -> buf.validate.RepeatedRules + 25, // 21: buf.validate.FieldRules.map:type_name -> buf.validate.MapRules + 26, // 22: buf.validate.FieldRules.any:type_name -> buf.validate.AnyRules + 27, // 23: buf.validate.FieldRules.duration:type_name -> 
buf.validate.DurationRules + 28, // 24: buf.validate.FieldRules.field_mask:type_name -> buf.validate.FieldMaskRules + 29, // 25: buf.validate.FieldRules.timestamp:type_name -> buf.validate.TimestampRules + 2, // 26: buf.validate.PredefinedRules.cel:type_name -> buf.validate.Rule + 1, // 27: buf.validate.StringRules.well_known_regex:type_name -> buf.validate.KnownRegex + 6, // 28: buf.validate.RepeatedRules.items:type_name -> buf.validate.FieldRules + 6, // 29: buf.validate.MapRules.keys:type_name -> buf.validate.FieldRules + 6, // 30: buf.validate.MapRules.values:type_name -> buf.validate.FieldRules + 34, // 31: buf.validate.DurationRules.const:type_name -> google.protobuf.Duration + 34, // 32: buf.validate.DurationRules.lt:type_name -> google.protobuf.Duration + 34, // 33: buf.validate.DurationRules.lte:type_name -> google.protobuf.Duration + 34, // 34: buf.validate.DurationRules.gt:type_name -> google.protobuf.Duration + 34, // 35: buf.validate.DurationRules.gte:type_name -> google.protobuf.Duration + 34, // 36: buf.validate.DurationRules.in:type_name -> google.protobuf.Duration + 34, // 37: buf.validate.DurationRules.not_in:type_name -> google.protobuf.Duration + 34, // 38: buf.validate.DurationRules.example:type_name -> google.protobuf.Duration + 35, // 39: buf.validate.FieldMaskRules.const:type_name -> google.protobuf.FieldMask + 35, // 40: buf.validate.FieldMaskRules.example:type_name -> google.protobuf.FieldMask + 36, // 41: buf.validate.TimestampRules.const:type_name -> google.protobuf.Timestamp + 36, // 42: buf.validate.TimestampRules.lt:type_name -> google.protobuf.Timestamp + 36, // 43: buf.validate.TimestampRules.lte:type_name -> google.protobuf.Timestamp + 36, // 44: buf.validate.TimestampRules.gt:type_name -> google.protobuf.Timestamp + 36, // 45: buf.validate.TimestampRules.gte:type_name -> google.protobuf.Timestamp + 34, // 46: buf.validate.TimestampRules.within:type_name -> google.protobuf.Duration + 36, // 47: 
buf.validate.TimestampRules.example:type_name -> google.protobuf.Timestamp + 31, // 48: buf.validate.Violations.violations:type_name -> buf.validate.Violation + 32, // 49: buf.validate.Violation.field:type_name -> buf.validate.FieldPath + 32, // 50: buf.validate.Violation.rule:type_name -> buf.validate.FieldPath + 33, // 51: buf.validate.FieldPath.elements:type_name -> buf.validate.FieldPathElement + 37, // 52: buf.validate.FieldPathElement.field_type:type_name -> google.protobuf.FieldDescriptorProto.Type + 37, // 53: buf.validate.FieldPathElement.key_type:type_name -> google.protobuf.FieldDescriptorProto.Type + 37, // 54: buf.validate.FieldPathElement.value_type:type_name -> google.protobuf.FieldDescriptorProto.Type + 38, // 55: buf.validate.message:extendee -> google.protobuf.MessageOptions + 39, // 56: buf.validate.oneof:extendee -> google.protobuf.OneofOptions + 40, // 57: buf.validate.field:extendee -> google.protobuf.FieldOptions + 40, // 58: buf.validate.predefined:extendee -> google.protobuf.FieldOptions + 3, // 59: buf.validate.message:type_name -> buf.validate.MessageRules + 5, // 60: buf.validate.oneof:type_name -> buf.validate.OneofRules + 6, // 61: buf.validate.field:type_name -> buf.validate.FieldRules + 7, // 62: buf.validate.predefined:type_name -> buf.validate.PredefinedRules + 63, // [63:63] is the sub-list for method output_type + 63, // [63:63] is the sub-list for method input_type + 59, // [59:63] is the sub-list for extension type_name + 55, // [55:59] is the sub-list for extension extendee + 0, // [0:55] is the sub-list for field type_name +} + +func init() { file_buf_validate_validate_proto_init() } +func file_buf_validate_validate_proto_init() { + if File_buf_validate_validate_proto != nil { + return + } + file_buf_validate_validate_proto_msgTypes[4].OneofWrappers = []any{ + (*FieldRules_Float)(nil), + (*FieldRules_Double)(nil), + (*FieldRules_Int32)(nil), + (*FieldRules_Int64)(nil), + (*FieldRules_Uint32)(nil), + (*FieldRules_Uint64)(nil), 
+ (*FieldRules_Sint32)(nil), + (*FieldRules_Sint64)(nil), + (*FieldRules_Fixed32)(nil), + (*FieldRules_Fixed64)(nil), + (*FieldRules_Sfixed32)(nil), + (*FieldRules_Sfixed64)(nil), + (*FieldRules_Bool)(nil), + (*FieldRules_String_)(nil), + (*FieldRules_Bytes)(nil), + (*FieldRules_Enum)(nil), + (*FieldRules_Repeated)(nil), + (*FieldRules_Map)(nil), + (*FieldRules_Any)(nil), + (*FieldRules_Duration)(nil), + (*FieldRules_FieldMask)(nil), + (*FieldRules_Timestamp)(nil), + } + file_buf_validate_validate_proto_msgTypes[6].OneofWrappers = []any{ + (*FloatRules_Lt)(nil), + (*FloatRules_Lte)(nil), + (*FloatRules_Gt)(nil), + (*FloatRules_Gte)(nil), + } + file_buf_validate_validate_proto_msgTypes[7].OneofWrappers = []any{ + (*DoubleRules_Lt)(nil), + (*DoubleRules_Lte)(nil), + (*DoubleRules_Gt)(nil), + (*DoubleRules_Gte)(nil), + } + file_buf_validate_validate_proto_msgTypes[8].OneofWrappers = []any{ + (*Int32Rules_Lt)(nil), + (*Int32Rules_Lte)(nil), + (*Int32Rules_Gt)(nil), + (*Int32Rules_Gte)(nil), + } + file_buf_validate_validate_proto_msgTypes[9].OneofWrappers = []any{ + (*Int64Rules_Lt)(nil), + (*Int64Rules_Lte)(nil), + (*Int64Rules_Gt)(nil), + (*Int64Rules_Gte)(nil), + } + file_buf_validate_validate_proto_msgTypes[10].OneofWrappers = []any{ + (*UInt32Rules_Lt)(nil), + (*UInt32Rules_Lte)(nil), + (*UInt32Rules_Gt)(nil), + (*UInt32Rules_Gte)(nil), + } + file_buf_validate_validate_proto_msgTypes[11].OneofWrappers = []any{ + (*UInt64Rules_Lt)(nil), + (*UInt64Rules_Lte)(nil), + (*UInt64Rules_Gt)(nil), + (*UInt64Rules_Gte)(nil), + } + file_buf_validate_validate_proto_msgTypes[12].OneofWrappers = []any{ + (*SInt32Rules_Lt)(nil), + (*SInt32Rules_Lte)(nil), + (*SInt32Rules_Gt)(nil), + (*SInt32Rules_Gte)(nil), + } + file_buf_validate_validate_proto_msgTypes[13].OneofWrappers = []any{ + (*SInt64Rules_Lt)(nil), + (*SInt64Rules_Lte)(nil), + (*SInt64Rules_Gt)(nil), + (*SInt64Rules_Gte)(nil), + } + file_buf_validate_validate_proto_msgTypes[14].OneofWrappers = []any{ + 
(*Fixed32Rules_Lt)(nil), + (*Fixed32Rules_Lte)(nil), + (*Fixed32Rules_Gt)(nil), + (*Fixed32Rules_Gte)(nil), + } + file_buf_validate_validate_proto_msgTypes[15].OneofWrappers = []any{ + (*Fixed64Rules_Lt)(nil), + (*Fixed64Rules_Lte)(nil), + (*Fixed64Rules_Gt)(nil), + (*Fixed64Rules_Gte)(nil), + } + file_buf_validate_validate_proto_msgTypes[16].OneofWrappers = []any{ + (*SFixed32Rules_Lt)(nil), + (*SFixed32Rules_Lte)(nil), + (*SFixed32Rules_Gt)(nil), + (*SFixed32Rules_Gte)(nil), + } + file_buf_validate_validate_proto_msgTypes[17].OneofWrappers = []any{ + (*SFixed64Rules_Lt)(nil), + (*SFixed64Rules_Lte)(nil), + (*SFixed64Rules_Gt)(nil), + (*SFixed64Rules_Gte)(nil), + } + file_buf_validate_validate_proto_msgTypes[19].OneofWrappers = []any{ + (*StringRules_Email)(nil), + (*StringRules_Hostname)(nil), + (*StringRules_Ip)(nil), + (*StringRules_Ipv4)(nil), + (*StringRules_Ipv6)(nil), + (*StringRules_Uri)(nil), + (*StringRules_UriRef)(nil), + (*StringRules_Address)(nil), + (*StringRules_Uuid)(nil), + (*StringRules_Tuuid)(nil), + (*StringRules_IpWithPrefixlen)(nil), + (*StringRules_Ipv4WithPrefixlen)(nil), + (*StringRules_Ipv6WithPrefixlen)(nil), + (*StringRules_IpPrefix)(nil), + (*StringRules_Ipv4Prefix)(nil), + (*StringRules_Ipv6Prefix)(nil), + (*StringRules_HostAndPort)(nil), + (*StringRules_Ulid)(nil), + (*StringRules_ProtobufFqn)(nil), + (*StringRules_ProtobufDotFqn)(nil), + (*StringRules_WellKnownRegex)(nil), + } + file_buf_validate_validate_proto_msgTypes[20].OneofWrappers = []any{ + (*BytesRules_Ip)(nil), + (*BytesRules_Ipv4)(nil), + (*BytesRules_Ipv6)(nil), + (*BytesRules_Uuid)(nil), + } + file_buf_validate_validate_proto_msgTypes[25].OneofWrappers = []any{ + (*DurationRules_Lt)(nil), + (*DurationRules_Lte)(nil), + (*DurationRules_Gt)(nil), + (*DurationRules_Gte)(nil), + } + file_buf_validate_validate_proto_msgTypes[27].OneofWrappers = []any{ + (*TimestampRules_Lt)(nil), + (*TimestampRules_Lte)(nil), + (*TimestampRules_LtNow)(nil), + (*TimestampRules_Gt)(nil), + 
(*TimestampRules_Gte)(nil), + (*TimestampRules_GtNow)(nil), + } + file_buf_validate_validate_proto_msgTypes[31].OneofWrappers = []any{ + (*FieldPathElement_Index)(nil), + (*FieldPathElement_BoolKey)(nil), + (*FieldPathElement_IntKey)(nil), + (*FieldPathElement_UintKey)(nil), + (*FieldPathElement_StringKey)(nil), + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_buf_validate_validate_proto_rawDesc), len(file_buf_validate_validate_proto_rawDesc)), + NumEnums: 2, + NumMessages: 32, + NumExtensions: 4, + NumServices: 0, + }, + GoTypes: file_buf_validate_validate_proto_goTypes, + DependencyIndexes: file_buf_validate_validate_proto_depIdxs, + EnumInfos: file_buf_validate_validate_proto_enumTypes, + MessageInfos: file_buf_validate_validate_proto_msgTypes, + ExtensionInfos: file_buf_validate_validate_proto_extTypes, + }.Build() + File_buf_validate_validate_proto = out.File + file_buf_validate_validate_proto_goTypes = nil + file_buf_validate_validate_proto_depIdxs = nil +} diff --git a/router/gen/proto/wg/cosmo/code_mode/yoko/v1/yoko.pb.go b/router/gen/proto/wg/cosmo/code_mode/yoko/v1/yoko.pb.go index c1fb97bbbd..69f831904e 100644 --- a/router/gen/proto/wg/cosmo/code_mode/yoko/v1/yoko.pb.go +++ b/router/gen/proto/wg/cosmo/code_mode/yoko/v1/yoko.pb.go @@ -7,6 +7,7 @@ package yokov1 import ( + _ "github.com/wundergraph/cosmo/router/gen/proto/buf/validate" protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" reflect "reflect" @@ -21,79 +22,29 @@ const ( _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) ) -type OperationKind int32 - -const ( - OperationKind_OPERATION_KIND_UNSPECIFIED OperationKind = 0 - OperationKind_OPERATION_KIND_QUERY OperationKind = 1 - OperationKind_OPERATION_KIND_MUTATION OperationKind = 2 -) - -// Enum value maps for OperationKind. 
-var ( - OperationKind_name = map[int32]string{ - 0: "OPERATION_KIND_UNSPECIFIED", - 1: "OPERATION_KIND_QUERY", - 2: "OPERATION_KIND_MUTATION", - } - OperationKind_value = map[string]int32{ - "OPERATION_KIND_UNSPECIFIED": 0, - "OPERATION_KIND_QUERY": 1, - "OPERATION_KIND_MUTATION": 2, - } -) - -func (x OperationKind) Enum() *OperationKind { - p := new(OperationKind) - *p = x - return p -} - -func (x OperationKind) String() string { - return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) -} - -func (OperationKind) Descriptor() protoreflect.EnumDescriptor { - return file_wg_cosmo_code_mode_yoko_v1_yoko_proto_enumTypes[0].Descriptor() -} - -func (OperationKind) Type() protoreflect.EnumType { - return &file_wg_cosmo_code_mode_yoko_v1_yoko_proto_enumTypes[0] -} - -func (x OperationKind) Number() protoreflect.EnumNumber { - return protoreflect.EnumNumber(x) -} - -// Deprecated: Use OperationKind.Descriptor instead. -func (OperationKind) EnumDescriptor() ([]byte, []int) { - return file_wg_cosmo_code_mode_yoko_v1_yoko_proto_rawDescGZIP(), []int{0} -} - -type IndexRequest struct { +type IndexSchemaRequest struct { state protoimpl.MessageState `protogen:"open.v1"` - // The supergraph SDL to index. Sent in full on every Index call; - // Yoko deduplicates internally and is free to short-circuit when - // the SDL is already known. - SchemaSdl string `protobuf:"bytes,1,opt,name=schema_sdl,json=schemaSdl,proto3" json:"schema_sdl,omitempty"` + // GraphQL Schema Definition Language (SDL) for the target API. + // Must contain at least one non-whitespace character. 
+ Sdl string `protobuf:"bytes,1,opt,name=sdl,proto3" json:"sdl,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } -func (x *IndexRequest) Reset() { - *x = IndexRequest{} +func (x *IndexSchemaRequest) Reset() { + *x = IndexSchemaRequest{} mi := &file_wg_cosmo_code_mode_yoko_v1_yoko_proto_msgTypes[0] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } -func (x *IndexRequest) String() string { +func (x *IndexSchemaRequest) String() string { return protoimpl.X.MessageStringOf(x) } -func (*IndexRequest) ProtoMessage() {} +func (*IndexSchemaRequest) ProtoMessage() {} -func (x *IndexRequest) ProtoReflect() protoreflect.Message { +func (x *IndexSchemaRequest) ProtoReflect() protoreflect.Message { mi := &file_wg_cosmo_code_mode_yoko_v1_yoko_proto_msgTypes[0] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -105,43 +56,40 @@ func (x *IndexRequest) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use IndexRequest.ProtoReflect.Descriptor instead. -func (*IndexRequest) Descriptor() ([]byte, []int) { +// Deprecated: Use IndexSchemaRequest.ProtoReflect.Descriptor instead. +func (*IndexSchemaRequest) Descriptor() ([]byte, []int) { return file_wg_cosmo_code_mode_yoko_v1_yoko_proto_rawDescGZIP(), []int{0} } -func (x *IndexRequest) GetSchemaSdl() string { +func (x *IndexSchemaRequest) GetSdl() string { if x != nil { - return x.SchemaSdl + return x.Sdl } return "" } -type IndexResponse struct { +type IndexSchemaResponse struct { state protoimpl.MessageState `protogen:"open.v1"` - // Opaque, Yoko-assigned identifier for this schema. Stable for as - // long as Yoko retains the index. Subsequent Search calls pass this - // back instead of the full SDL. Idempotent: the same SDL returns - // the same schema_id. + // Stable id derived from the indexed SDL; pass to GenerateQuery. 
SchemaId string `protobuf:"bytes,1,opt,name=schema_id,json=schemaId,proto3" json:"schema_id,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } -func (x *IndexResponse) Reset() { - *x = IndexResponse{} +func (x *IndexSchemaResponse) Reset() { + *x = IndexSchemaResponse{} mi := &file_wg_cosmo_code_mode_yoko_v1_yoko_proto_msgTypes[1] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } -func (x *IndexResponse) String() string { +func (x *IndexSchemaResponse) String() string { return protoimpl.X.MessageStringOf(x) } -func (*IndexResponse) ProtoMessage() {} +func (*IndexSchemaResponse) ProtoMessage() {} -func (x *IndexResponse) ProtoReflect() protoreflect.Message { +func (x *IndexSchemaResponse) ProtoReflect() protoreflect.Message { mi := &file_wg_cosmo_code_mode_yoko_v1_yoko_proto_msgTypes[1] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -153,49 +101,42 @@ func (x *IndexResponse) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use IndexResponse.ProtoReflect.Descriptor instead. -func (*IndexResponse) Descriptor() ([]byte, []int) { +// Deprecated: Use IndexSchemaResponse.ProtoReflect.Descriptor instead. +func (*IndexSchemaResponse) Descriptor() ([]byte, []int) { return file_wg_cosmo_code_mode_yoko_v1_yoko_proto_rawDescGZIP(), []int{1} } -func (x *IndexResponse) GetSchemaId() string { +func (x *IndexSchemaResponse) GetSchemaId() string { if x != nil { return x.SchemaId } return "" } -type SearchRequest struct { +type GenerateQueryRequest struct { state protoimpl.MessageState `protogen:"open.v1"` - // Batch of natural-language prompts. Bounded at 20 by the host. - Prompts []string `protobuf:"bytes,1,rep,name=prompts,proto3" json:"prompts,omitempty"` - // Identifier returned by a prior Index call. If Yoko no longer - // recognizes the id (e.g. 
eviction, restart), it MUST return the - // Connect error code NOT_FOUND; the router re-indexes and retries - // the call exactly once. - SchemaId string `protobuf:"bytes,2,opt,name=schema_id,json=schemaId,proto3" json:"schema_id,omitempty"` - // Opaque MCP session ID for telemetry correlation only. - // Yoko MUST NOT use this for stateful behavior — sessions are owned - // by the router. - SessionId string `protobuf:"bytes,3,opt,name=session_id,json=sessionId,proto3" json:"session_id,omitempty"` + // schema_id from a prior IndexSchema call. + SchemaId string `protobuf:"bytes,1,opt,name=schema_id,json=schemaId,proto3" json:"schema_id,omitempty"` + // Natural-language description of what the caller wants to fetch. + Prompt string `protobuf:"bytes,2,opt,name=prompt,proto3" json:"prompt,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } -func (x *SearchRequest) Reset() { - *x = SearchRequest{} +func (x *GenerateQueryRequest) Reset() { + *x = GenerateQueryRequest{} mi := &file_wg_cosmo_code_mode_yoko_v1_yoko_proto_msgTypes[2] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } -func (x *SearchRequest) String() string { +func (x *GenerateQueryRequest) String() string { return protoimpl.X.MessageStringOf(x) } -func (*SearchRequest) ProtoMessage() {} +func (*GenerateQueryRequest) ProtoMessage() {} -func (x *SearchRequest) ProtoReflect() protoreflect.Message { +func (x *GenerateQueryRequest) ProtoReflect() protoreflect.Message { mi := &file_wg_cosmo_code_mode_yoko_v1_yoko_proto_msgTypes[2] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -207,56 +148,46 @@ func (x *SearchRequest) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use SearchRequest.ProtoReflect.Descriptor instead. -func (*SearchRequest) Descriptor() ([]byte, []int) { +// Deprecated: Use GenerateQueryRequest.ProtoReflect.Descriptor instead. 
+func (*GenerateQueryRequest) Descriptor() ([]byte, []int) { return file_wg_cosmo_code_mode_yoko_v1_yoko_proto_rawDescGZIP(), []int{2} } -func (x *SearchRequest) GetPrompts() []string { - if x != nil { - return x.Prompts - } - return nil -} - -func (x *SearchRequest) GetSchemaId() string { +func (x *GenerateQueryRequest) GetSchemaId() string { if x != nil { return x.SchemaId } return "" } -func (x *SearchRequest) GetSessionId() string { +func (x *GenerateQueryRequest) GetPrompt() string { if x != nil { - return x.SessionId + return x.Prompt } return "" } -type SearchResponse struct { - state protoimpl.MessageState `protogen:"open.v1"` - // Operations across all prompts, already deduplicated and ranked. - // Order is significant: earlier entries rank higher and are preferred - // when bundle truncation drops from the tail. - Operations []*GeneratedOperation `protobuf:"bytes,1,rep,name=operations,proto3" json:"operations,omitempty"` +type GenerateQueryResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Resolution *Resolution `protobuf:"bytes,1,opt,name=resolution,proto3" json:"resolution,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } -func (x *SearchResponse) Reset() { - *x = SearchResponse{} +func (x *GenerateQueryResponse) Reset() { + *x = GenerateQueryResponse{} mi := &file_wg_cosmo_code_mode_yoko_v1_yoko_proto_msgTypes[3] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } -func (x *SearchResponse) String() string { +func (x *GenerateQueryResponse) String() string { return protoimpl.X.MessageStringOf(x) } -func (*SearchResponse) ProtoMessage() {} +func (*GenerateQueryResponse) ProtoMessage() {} -func (x *SearchResponse) ProtoReflect() protoreflect.Message { +func (x *GenerateQueryResponse) ProtoReflect() protoreflect.Message { mi := &file_wg_cosmo_code_mode_yoko_v1_yoko_proto_msgTypes[3] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -268,50 +199,47 @@ 
func (x *SearchResponse) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use SearchResponse.ProtoReflect.Descriptor instead. -func (*SearchResponse) Descriptor() ([]byte, []int) { +// Deprecated: Use GenerateQueryResponse.ProtoReflect.Descriptor instead. +func (*GenerateQueryResponse) Descriptor() ([]byte, []int) { return file_wg_cosmo_code_mode_yoko_v1_yoko_proto_rawDescGZIP(), []int{3} } -func (x *SearchResponse) GetOperations() []*GeneratedOperation { +func (x *GenerateQueryResponse) GetResolution() *Resolution { if x != nil { - return x.Operations + return x.Resolution } return nil } -type GeneratedOperation struct { +type Resolution struct { state protoimpl.MessageState `protogen:"open.v1"` - // Suggested operation name (camelCase preferred). The host applies - // its own identifier normalization and in-session collision-suffix - // logic on top of this — see §6. - Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` - // GraphQL operation body (query or mutation source text). - Body string `protobuf:"bytes,2,opt,name=body,proto3" json:"body,omitempty"` - // Operation kind. Subscriptions are out of scope; if Yoko returns - // one, the host drops it with a single warn log. - Kind OperationKind `protobuf:"varint,3,opt,name=kind,proto3,enum=wundergraph.cosmo.code_mode.yoko.v1.OperationKind" json:"kind,omitempty"` - // Human-readable description, surfaced as JSDoc on the typed - // `tools.` signature in the rendered bundle. - Description string `protobuf:"bytes,4,opt,name=description,proto3" json:"description,omitempty"` + // One entry per produced query; each is a self-contained operation + // with a natural-language description of what it does. + Queries []*ResolvedQuery `protobuf:"bytes,1,rep,name=queries,proto3" json:"queries,omitempty"` + // One entry per requirement we could not satisfy; each carries a + // natural-language reason. 
+ Unsatisfied []*Unsatisfied `protobuf:"bytes,2,rep,name=unsatisfied,proto3" json:"unsatisfied,omitempty"` + // True when the propose agent ran out of turns before committing + // every requirement; clients may want to retry with a tighter prompt. + Truncated bool `protobuf:"varint,3,opt,name=truncated,proto3" json:"truncated,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } -func (x *GeneratedOperation) Reset() { - *x = GeneratedOperation{} +func (x *Resolution) Reset() { + *x = Resolution{} mi := &file_wg_cosmo_code_mode_yoko_v1_yoko_proto_msgTypes[4] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } -func (x *GeneratedOperation) String() string { +func (x *Resolution) String() string { return protoimpl.X.MessageStringOf(x) } -func (*GeneratedOperation) ProtoMessage() {} +func (*Resolution) ProtoMessage() {} -func (x *GeneratedOperation) ProtoReflect() protoreflect.Message { +func (x *Resolution) ProtoReflect() protoreflect.Message { mi := &file_wg_cosmo_code_mode_yoko_v1_yoko_proto_msgTypes[4] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -323,35 +251,159 @@ func (x *GeneratedOperation) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use GeneratedOperation.ProtoReflect.Descriptor instead. -func (*GeneratedOperation) Descriptor() ([]byte, []int) { +// Deprecated: Use Resolution.ProtoReflect.Descriptor instead. 
+func (*Resolution) Descriptor() ([]byte, []int) { return file_wg_cosmo_code_mode_yoko_v1_yoko_proto_rawDescGZIP(), []int{4} } -func (x *GeneratedOperation) GetName() string { +func (x *Resolution) GetQueries() []*ResolvedQuery { + if x != nil { + return x.Queries + } + return nil +} + +func (x *Resolution) GetUnsatisfied() []*Unsatisfied { + if x != nil { + return x.Unsatisfied + } + return nil +} + +func (x *Resolution) GetTruncated() bool { + if x != nil { + return x.Truncated + } + return false +} + +type ResolvedQuery struct { + state protoimpl.MessageState `protogen:"open.v1"` + // One short user-facing sentence describing what this query does. + Description string `protobuf:"bytes,1,opt,name=description,proto3" json:"description,omitempty"` + // GraphQL operation document — exactly one named operation. + Document string `protobuf:"bytes,2,opt,name=document,proto3" json:"document,omitempty"` + // Operation name parsed from the document (e.g. "GetUserPosts"). + OperationName string `protobuf:"bytes,3,opt,name=operation_name,json=operationName,proto3" json:"operation_name,omitempty"` + // One of "query", "mutation", "subscription". + OperationType string `protobuf:"bytes,4,opt,name=operation_type,json=operationType,proto3" json:"operation_type,omitempty"` + // JSON Schema for the operation's $variables object, derived + // statically from the document. Carried as a JSON-encoded string so + // JSON clients see a readable schema (a `bytes` field would surface + // as base64 over JSON transport). 
+ VariablesSchema string `protobuf:"bytes,5,opt,name=variables_schema,json=variablesSchema,proto3" json:"variables_schema,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ResolvedQuery) Reset() { + *x = ResolvedQuery{} + mi := &file_wg_cosmo_code_mode_yoko_v1_yoko_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ResolvedQuery) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ResolvedQuery) ProtoMessage() {} + +func (x *ResolvedQuery) ProtoReflect() protoreflect.Message { + mi := &file_wg_cosmo_code_mode_yoko_v1_yoko_proto_msgTypes[5] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ResolvedQuery.ProtoReflect.Descriptor instead. +func (*ResolvedQuery) Descriptor() ([]byte, []int) { + return file_wg_cosmo_code_mode_yoko_v1_yoko_proto_rawDescGZIP(), []int{5} +} + +func (x *ResolvedQuery) GetDescription() string { + if x != nil { + return x.Description + } + return "" +} + +func (x *ResolvedQuery) GetDocument() string { if x != nil { - return x.Name + return x.Document } return "" } -func (x *GeneratedOperation) GetBody() string { +func (x *ResolvedQuery) GetOperationName() string { if x != nil { - return x.Body + return x.OperationName } return "" } -func (x *GeneratedOperation) GetKind() OperationKind { +func (x *ResolvedQuery) GetOperationType() string { if x != nil { - return x.Kind + return x.OperationType } - return OperationKind_OPERATION_KIND_UNSPECIFIED + return "" } -func (x *GeneratedOperation) GetDescription() string { +func (x *ResolvedQuery) GetVariablesSchema() string { if x != nil { - return x.Description + return x.VariablesSchema + } + return "" +} + +type Unsatisfied struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Natural-language explanation 
of why this requirement could not + // be satisfied (e.g. "no field on the schema carries that filter + // dimension"). + Reason string `protobuf:"bytes,1,opt,name=reason,proto3" json:"reason,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Unsatisfied) Reset() { + *x = Unsatisfied{} + mi := &file_wg_cosmo_code_mode_yoko_v1_yoko_proto_msgTypes[6] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Unsatisfied) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Unsatisfied) ProtoMessage() {} + +func (x *Unsatisfied) ProtoReflect() protoreflect.Message { + mi := &file_wg_cosmo_code_mode_yoko_v1_yoko_proto_msgTypes[6] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Unsatisfied.ProtoReflect.Descriptor instead. +func (*Unsatisfied) Descriptor() ([]byte, []int) { + return file_wg_cosmo_code_mode_yoko_v1_yoko_proto_rawDescGZIP(), []int{6} +} + +func (x *Unsatisfied) GetReason() string { + if x != nil { + return x.Reason } return "" } @@ -360,34 +412,35 @@ var File_wg_cosmo_code_mode_yoko_v1_yoko_proto protoreflect.FileDescriptor const file_wg_cosmo_code_mode_yoko_v1_yoko_proto_rawDesc = "" + "\n" + - "%wg/cosmo/code_mode/yoko/v1/yoko.proto\x12#wundergraph.cosmo.code_mode.yoko.v1\"-\n" + - "\fIndexRequest\x12\x1d\n" + - "\n" + - "schema_sdl\x18\x01 \x01(\tR\tschemaSdl\",\n" + - "\rIndexResponse\x12\x1b\n" + - "\tschema_id\x18\x01 \x01(\tR\bschemaId\"e\n" + - "\rSearchRequest\x12\x18\n" + - "\aprompts\x18\x01 \x03(\tR\aprompts\x12\x1b\n" + - "\tschema_id\x18\x02 \x01(\tR\bschemaId\x12\x1d\n" + + "%wg/cosmo/code_mode/yoko/v1/yoko.proto\x12\ayoko.v1\x1a\x1bbuf/validate/validate.proto\"1\n" + + "\x12IndexSchemaRequest\x12\x1b\n" + + "\x03sdl\x18\x01 \x01(\tB\t\xbaH\x06r\x042\x02\\SR\x03sdl\"2\n" + + 
"\x13IndexSchemaResponse\x12\x1b\n" + + "\tschema_id\x18\x01 \x01(\tR\bschemaId\"a\n" + + "\x14GenerateQueryRequest\x12&\n" + + "\tschema_id\x18\x01 \x01(\tB\t\xbaH\x06r\x042\x02\\SR\bschemaId\x12!\n" + + "\x06prompt\x18\x02 \x01(\tB\t\xbaH\x06r\x042\x02\\SR\x06prompt\"L\n" + + "\x15GenerateQueryResponse\x123\n" + "\n" + - "session_id\x18\x03 \x01(\tR\tsessionId\"i\n" + - "\x0eSearchResponse\x12W\n" + + "resolution\x18\x01 \x01(\v2\x13.yoko.v1.ResolutionR\n" + + "resolution\"\x94\x01\n" + "\n" + - "operations\x18\x01 \x03(\v27.wundergraph.cosmo.code_mode.yoko.v1.GeneratedOperationR\n" + - "operations\"\xa6\x01\n" + - "\x12GeneratedOperation\x12\x12\n" + - "\x04name\x18\x01 \x01(\tR\x04name\x12\x12\n" + - "\x04body\x18\x02 \x01(\tR\x04body\x12F\n" + - "\x04kind\x18\x03 \x01(\x0e22.wundergraph.cosmo.code_mode.yoko.v1.OperationKindR\x04kind\x12 \n" + - "\vdescription\x18\x04 \x01(\tR\vdescription*f\n" + - "\rOperationKind\x12\x1e\n" + - "\x1aOPERATION_KIND_UNSPECIFIED\x10\x00\x12\x18\n" + - "\x14OPERATION_KIND_QUERY\x10\x01\x12\x1b\n" + - "\x17OPERATION_KIND_MUTATION\x10\x022\xf0\x01\n" + - "\vYokoService\x12n\n" + - "\x05Index\x121.wundergraph.cosmo.code_mode.yoko.v1.IndexRequest\x1a2.wundergraph.cosmo.code_mode.yoko.v1.IndexResponse\x12q\n" + - "\x06Search\x122.wundergraph.cosmo.code_mode.yoko.v1.SearchRequest\x1a3.wundergraph.cosmo.code_mode.yoko.v1.SearchResponseB\xb2\x02\n" + - "'com.wundergraph.cosmo.code_mode.yoko.v1B\tYokoProtoP\x01ZOgithub.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/code_mode/yoko/v1;yokov1\xa2\x02\x04WCCY\xaa\x02\"Wundergraph.Cosmo.CodeMode.Yoko.V1\xca\x02\"Wundergraph\\Cosmo\\CodeMode\\Yoko\\V1\xe2\x02.Wundergraph\\Cosmo\\CodeMode\\Yoko\\V1\\GPBMetadata\xea\x02&Wundergraph::Cosmo::CodeMode::Yoko::V1b\x06proto3" + "Resolution\x120\n" + + "\aqueries\x18\x01 \x03(\v2\x16.yoko.v1.ResolvedQueryR\aqueries\x126\n" + + "\vunsatisfied\x18\x02 \x03(\v2\x14.yoko.v1.UnsatisfiedR\vunsatisfied\x12\x1c\n" + + "\ttruncated\x18\x03 
\x01(\bR\ttruncated\"\xc6\x01\n" + + "\rResolvedQuery\x12 \n" + + "\vdescription\x18\x01 \x01(\tR\vdescription\x12\x1a\n" + + "\bdocument\x18\x02 \x01(\tR\bdocument\x12%\n" + + "\x0eoperation_name\x18\x03 \x01(\tR\roperationName\x12%\n" + + "\x0eoperation_type\x18\x04 \x01(\tR\roperationType\x12)\n" + + "\x10variables_schema\x18\x05 \x01(\tR\x0fvariablesSchema\"%\n" + + "\vUnsatisfied\x12\x16\n" + + "\x06reason\x18\x01 \x01(\tR\x06reason2\xa7\x01\n" + + "\vYokoService\x12H\n" + + "\vIndexSchema\x12\x1b.yoko.v1.IndexSchemaRequest\x1a\x1c.yoko.v1.IndexSchemaResponse\x12N\n" + + "\rGenerateQuery\x12\x1d.yoko.v1.GenerateQueryRequest\x1a\x1e.yoko.v1.GenerateQueryResponseB\xa6\x01\n" + + "\vcom.yoko.v1B\tYokoProtoP\x01ZOgithub.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/code_mode/yoko/v1;yokov1\xa2\x02\x03YXX\xaa\x02\aYoko.V1\xca\x02\aYoko\\V1\xe2\x02\x13Yoko\\V1\\GPBMetadata\xea\x02\bYoko::V1b\x06proto3" var ( file_wg_cosmo_code_mode_yoko_v1_yoko_proto_rawDescOnce sync.Once @@ -401,28 +454,29 @@ func file_wg_cosmo_code_mode_yoko_v1_yoko_proto_rawDescGZIP() []byte { return file_wg_cosmo_code_mode_yoko_v1_yoko_proto_rawDescData } -var file_wg_cosmo_code_mode_yoko_v1_yoko_proto_enumTypes = make([]protoimpl.EnumInfo, 1) -var file_wg_cosmo_code_mode_yoko_v1_yoko_proto_msgTypes = make([]protoimpl.MessageInfo, 5) +var file_wg_cosmo_code_mode_yoko_v1_yoko_proto_msgTypes = make([]protoimpl.MessageInfo, 7) var file_wg_cosmo_code_mode_yoko_v1_yoko_proto_goTypes = []any{ - (OperationKind)(0), // 0: wundergraph.cosmo.code_mode.yoko.v1.OperationKind - (*IndexRequest)(nil), // 1: wundergraph.cosmo.code_mode.yoko.v1.IndexRequest - (*IndexResponse)(nil), // 2: wundergraph.cosmo.code_mode.yoko.v1.IndexResponse - (*SearchRequest)(nil), // 3: wundergraph.cosmo.code_mode.yoko.v1.SearchRequest - (*SearchResponse)(nil), // 4: wundergraph.cosmo.code_mode.yoko.v1.SearchResponse - (*GeneratedOperation)(nil), // 5: wundergraph.cosmo.code_mode.yoko.v1.GeneratedOperation + 
(*IndexSchemaRequest)(nil), // 0: yoko.v1.IndexSchemaRequest + (*IndexSchemaResponse)(nil), // 1: yoko.v1.IndexSchemaResponse + (*GenerateQueryRequest)(nil), // 2: yoko.v1.GenerateQueryRequest + (*GenerateQueryResponse)(nil), // 3: yoko.v1.GenerateQueryResponse + (*Resolution)(nil), // 4: yoko.v1.Resolution + (*ResolvedQuery)(nil), // 5: yoko.v1.ResolvedQuery + (*Unsatisfied)(nil), // 6: yoko.v1.Unsatisfied } var file_wg_cosmo_code_mode_yoko_v1_yoko_proto_depIdxs = []int32{ - 5, // 0: wundergraph.cosmo.code_mode.yoko.v1.SearchResponse.operations:type_name -> wundergraph.cosmo.code_mode.yoko.v1.GeneratedOperation - 0, // 1: wundergraph.cosmo.code_mode.yoko.v1.GeneratedOperation.kind:type_name -> wundergraph.cosmo.code_mode.yoko.v1.OperationKind - 1, // 2: wundergraph.cosmo.code_mode.yoko.v1.YokoService.Index:input_type -> wundergraph.cosmo.code_mode.yoko.v1.IndexRequest - 3, // 3: wundergraph.cosmo.code_mode.yoko.v1.YokoService.Search:input_type -> wundergraph.cosmo.code_mode.yoko.v1.SearchRequest - 2, // 4: wundergraph.cosmo.code_mode.yoko.v1.YokoService.Index:output_type -> wundergraph.cosmo.code_mode.yoko.v1.IndexResponse - 4, // 5: wundergraph.cosmo.code_mode.yoko.v1.YokoService.Search:output_type -> wundergraph.cosmo.code_mode.yoko.v1.SearchResponse - 4, // [4:6] is the sub-list for method output_type - 2, // [2:4] is the sub-list for method input_type - 2, // [2:2] is the sub-list for extension type_name - 2, // [2:2] is the sub-list for extension extendee - 0, // [0:2] is the sub-list for field type_name + 4, // 0: yoko.v1.GenerateQueryResponse.resolution:type_name -> yoko.v1.Resolution + 5, // 1: yoko.v1.Resolution.queries:type_name -> yoko.v1.ResolvedQuery + 6, // 2: yoko.v1.Resolution.unsatisfied:type_name -> yoko.v1.Unsatisfied + 0, // 3: yoko.v1.YokoService.IndexSchema:input_type -> yoko.v1.IndexSchemaRequest + 2, // 4: yoko.v1.YokoService.GenerateQuery:input_type -> yoko.v1.GenerateQueryRequest + 1, // 5: yoko.v1.YokoService.IndexSchema:output_type -> 
yoko.v1.IndexSchemaResponse + 3, // 6: yoko.v1.YokoService.GenerateQuery:output_type -> yoko.v1.GenerateQueryResponse + 5, // [5:7] is the sub-list for method output_type + 3, // [3:5] is the sub-list for method input_type + 3, // [3:3] is the sub-list for extension type_name + 3, // [3:3] is the sub-list for extension extendee + 0, // [0:3] is the sub-list for field type_name } func init() { file_wg_cosmo_code_mode_yoko_v1_yoko_proto_init() } @@ -435,14 +489,13 @@ func file_wg_cosmo_code_mode_yoko_v1_yoko_proto_init() { File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_wg_cosmo_code_mode_yoko_v1_yoko_proto_rawDesc), len(file_wg_cosmo_code_mode_yoko_v1_yoko_proto_rawDesc)), - NumEnums: 1, - NumMessages: 5, + NumEnums: 0, + NumMessages: 7, NumExtensions: 0, NumServices: 1, }, GoTypes: file_wg_cosmo_code_mode_yoko_v1_yoko_proto_goTypes, DependencyIndexes: file_wg_cosmo_code_mode_yoko_v1_yoko_proto_depIdxs, - EnumInfos: file_wg_cosmo_code_mode_yoko_v1_yoko_proto_enumTypes, MessageInfos: file_wg_cosmo_code_mode_yoko_v1_yoko_proto_msgTypes, }.Build() File_wg_cosmo_code_mode_yoko_v1_yoko_proto = out.File diff --git a/router/gen/proto/wg/cosmo/code_mode/yoko/v1/yokov1connect/yoko.connect.go b/router/gen/proto/wg/cosmo/code_mode/yoko/v1/yokov1connect/yoko.connect.go index 1e157644aa..9bf06bb9cc 100644 --- a/router/gen/proto/wg/cosmo/code_mode/yoko/v1/yokov1connect/yoko.connect.go +++ b/router/gen/proto/wg/cosmo/code_mode/yoko/v1/yokov1connect/yoko.connect.go @@ -22,7 +22,7 @@ const _ = connect.IsAtLeastVersion1_13_0 const ( // YokoServiceName is the fully-qualified name of the YokoService service. - YokoServiceName = "wundergraph.cosmo.code_mode.yoko.v1.YokoService" + YokoServiceName = "yoko.v1.YokoService" ) // These constants are the fully-qualified names of the RPCs defined in this package. 
They're @@ -33,45 +33,50 @@ const ( // reflection-formatted method names, remove the leading slash and convert the remaining slash to a // period. const ( - // YokoServiceIndexProcedure is the fully-qualified name of the YokoService's Index RPC. - YokoServiceIndexProcedure = "/wundergraph.cosmo.code_mode.yoko.v1.YokoService/Index" - // YokoServiceSearchProcedure is the fully-qualified name of the YokoService's Search RPC. - YokoServiceSearchProcedure = "/wundergraph.cosmo.code_mode.yoko.v1.YokoService/Search" + // YokoServiceIndexSchemaProcedure is the fully-qualified name of the YokoService's IndexSchema RPC. + YokoServiceIndexSchemaProcedure = "/yoko.v1.YokoService/IndexSchema" + // YokoServiceGenerateQueryProcedure is the fully-qualified name of the YokoService's GenerateQuery + // RPC. + YokoServiceGenerateQueryProcedure = "/yoko.v1.YokoService/GenerateQuery" ) // These variables are the protoreflect.Descriptor objects for the RPCs defined in this package. var ( - yokoServiceServiceDescriptor = v1.File_wg_cosmo_code_mode_yoko_v1_yoko_proto.Services().ByName("YokoService") - yokoServiceIndexMethodDescriptor = yokoServiceServiceDescriptor.Methods().ByName("Index") - yokoServiceSearchMethodDescriptor = yokoServiceServiceDescriptor.Methods().ByName("Search") + yokoServiceServiceDescriptor = v1.File_wg_cosmo_code_mode_yoko_v1_yoko_proto.Services().ByName("YokoService") + yokoServiceIndexSchemaMethodDescriptor = yokoServiceServiceDescriptor.Methods().ByName("IndexSchema") + yokoServiceGenerateQueryMethodDescriptor = yokoServiceServiceDescriptor.Methods().ByName("GenerateQuery") ) -// YokoServiceClient is a client for the wundergraph.cosmo.code_mode.yoko.v1.YokoService service. +// YokoServiceClient is a client for the yoko.v1.YokoService service. 
type YokoServiceClient interface { - Index(context.Context, *connect.Request[v1.IndexRequest]) (*connect.Response[v1.IndexResponse], error) - Search(context.Context, *connect.Request[v1.SearchRequest]) (*connect.Response[v1.SearchResponse], error) + // IndexSchema parses, enriches, embeds and indexes a GraphQL SDL. + // Returns the deterministic schema_id callers pass to GenerateQuery. + IndexSchema(context.Context, *connect.Request[v1.IndexSchemaRequest]) (*connect.Response[v1.IndexSchemaResponse], error) + // GenerateQuery turns a natural-language prompt into one or more + // compiled GraphQL operations against the previously indexed schema. + GenerateQuery(context.Context, *connect.Request[v1.GenerateQueryRequest]) (*connect.Response[v1.GenerateQueryResponse], error) } -// NewYokoServiceClient constructs a client for the wundergraph.cosmo.code_mode.yoko.v1.YokoService -// service. By default, it uses the Connect protocol with the binary Protobuf Codec, asks for -// gzipped responses, and sends uncompressed requests. To use the gRPC or gRPC-Web protocols, supply -// the connect.WithGRPC() or connect.WithGRPCWeb() options. +// NewYokoServiceClient constructs a client for the yoko.v1.YokoService service. By default, it uses +// the Connect protocol with the binary Protobuf Codec, asks for gzipped responses, and sends +// uncompressed requests. To use the gRPC or gRPC-Web protocols, supply the connect.WithGRPC() or +// connect.WithGRPCWeb() options. // // The URL supplied here should be the base URL for the Connect or gRPC server (for example, // http://api.acme.com or https://acme.com/grpc). 
func NewYokoServiceClient(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) YokoServiceClient { baseURL = strings.TrimRight(baseURL, "/") return &yokoServiceClient{ - index: connect.NewClient[v1.IndexRequest, v1.IndexResponse]( + indexSchema: connect.NewClient[v1.IndexSchemaRequest, v1.IndexSchemaResponse]( httpClient, - baseURL+YokoServiceIndexProcedure, - connect.WithSchema(yokoServiceIndexMethodDescriptor), + baseURL+YokoServiceIndexSchemaProcedure, + connect.WithSchema(yokoServiceIndexSchemaMethodDescriptor), connect.WithClientOptions(opts...), ), - search: connect.NewClient[v1.SearchRequest, v1.SearchResponse]( + generateQuery: connect.NewClient[v1.GenerateQueryRequest, v1.GenerateQueryResponse]( httpClient, - baseURL+YokoServiceSearchProcedure, - connect.WithSchema(yokoServiceSearchMethodDescriptor), + baseURL+YokoServiceGenerateQueryProcedure, + connect.WithSchema(yokoServiceGenerateQueryMethodDescriptor), connect.WithClientOptions(opts...), ), } @@ -79,25 +84,28 @@ func NewYokoServiceClient(httpClient connect.HTTPClient, baseURL string, opts .. // yokoServiceClient implements YokoServiceClient. type yokoServiceClient struct { - index *connect.Client[v1.IndexRequest, v1.IndexResponse] - search *connect.Client[v1.SearchRequest, v1.SearchResponse] + indexSchema *connect.Client[v1.IndexSchemaRequest, v1.IndexSchemaResponse] + generateQuery *connect.Client[v1.GenerateQueryRequest, v1.GenerateQueryResponse] } -// Index calls wundergraph.cosmo.code_mode.yoko.v1.YokoService.Index. -func (c *yokoServiceClient) Index(ctx context.Context, req *connect.Request[v1.IndexRequest]) (*connect.Response[v1.IndexResponse], error) { - return c.index.CallUnary(ctx, req) +// IndexSchema calls yoko.v1.YokoService.IndexSchema. 
+func (c *yokoServiceClient) IndexSchema(ctx context.Context, req *connect.Request[v1.IndexSchemaRequest]) (*connect.Response[v1.IndexSchemaResponse], error) { + return c.indexSchema.CallUnary(ctx, req) } -// Search calls wundergraph.cosmo.code_mode.yoko.v1.YokoService.Search. -func (c *yokoServiceClient) Search(ctx context.Context, req *connect.Request[v1.SearchRequest]) (*connect.Response[v1.SearchResponse], error) { - return c.search.CallUnary(ctx, req) +// GenerateQuery calls yoko.v1.YokoService.GenerateQuery. +func (c *yokoServiceClient) GenerateQuery(ctx context.Context, req *connect.Request[v1.GenerateQueryRequest]) (*connect.Response[v1.GenerateQueryResponse], error) { + return c.generateQuery.CallUnary(ctx, req) } -// YokoServiceHandler is an implementation of the wundergraph.cosmo.code_mode.yoko.v1.YokoService -// service. +// YokoServiceHandler is an implementation of the yoko.v1.YokoService service. type YokoServiceHandler interface { - Index(context.Context, *connect.Request[v1.IndexRequest]) (*connect.Response[v1.IndexResponse], error) - Search(context.Context, *connect.Request[v1.SearchRequest]) (*connect.Response[v1.SearchResponse], error) + // IndexSchema parses, enriches, embeds and indexes a GraphQL SDL. + // Returns the deterministic schema_id callers pass to GenerateQuery. + IndexSchema(context.Context, *connect.Request[v1.IndexSchemaRequest]) (*connect.Response[v1.IndexSchemaResponse], error) + // GenerateQuery turns a natural-language prompt into one or more + // compiled GraphQL operations against the previously indexed schema. + GenerateQuery(context.Context, *connect.Request[v1.GenerateQueryRequest]) (*connect.Response[v1.GenerateQueryResponse], error) } // NewYokoServiceHandler builds an HTTP handler from the service implementation. It returns the path @@ -106,24 +114,24 @@ type YokoServiceHandler interface { // By default, handlers support the Connect, gRPC, and gRPC-Web protocols with the binary Protobuf // and JSON codecs. 
They also support gzip compression. func NewYokoServiceHandler(svc YokoServiceHandler, opts ...connect.HandlerOption) (string, http.Handler) { - yokoServiceIndexHandler := connect.NewUnaryHandler( - YokoServiceIndexProcedure, - svc.Index, - connect.WithSchema(yokoServiceIndexMethodDescriptor), + yokoServiceIndexSchemaHandler := connect.NewUnaryHandler( + YokoServiceIndexSchemaProcedure, + svc.IndexSchema, + connect.WithSchema(yokoServiceIndexSchemaMethodDescriptor), connect.WithHandlerOptions(opts...), ) - yokoServiceSearchHandler := connect.NewUnaryHandler( - YokoServiceSearchProcedure, - svc.Search, - connect.WithSchema(yokoServiceSearchMethodDescriptor), + yokoServiceGenerateQueryHandler := connect.NewUnaryHandler( + YokoServiceGenerateQueryProcedure, + svc.GenerateQuery, + connect.WithSchema(yokoServiceGenerateQueryMethodDescriptor), connect.WithHandlerOptions(opts...), ) - return "/wundergraph.cosmo.code_mode.yoko.v1.YokoService/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + return "/yoko.v1.YokoService/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { - case YokoServiceIndexProcedure: - yokoServiceIndexHandler.ServeHTTP(w, r) - case YokoServiceSearchProcedure: - yokoServiceSearchHandler.ServeHTTP(w, r) + case YokoServiceIndexSchemaProcedure: + yokoServiceIndexSchemaHandler.ServeHTTP(w, r) + case YokoServiceGenerateQueryProcedure: + yokoServiceGenerateQueryHandler.ServeHTTP(w, r) default: http.NotFound(w, r) } @@ -133,10 +141,10 @@ func NewYokoServiceHandler(svc YokoServiceHandler, opts ...connect.HandlerOption // UnimplementedYokoServiceHandler returns CodeUnimplemented from all methods. 
type UnimplementedYokoServiceHandler struct{} -func (UnimplementedYokoServiceHandler) Index(context.Context, *connect.Request[v1.IndexRequest]) (*connect.Response[v1.IndexResponse], error) { - return nil, connect.NewError(connect.CodeUnimplemented, errors.New("wundergraph.cosmo.code_mode.yoko.v1.YokoService.Index is not implemented")) +func (UnimplementedYokoServiceHandler) IndexSchema(context.Context, *connect.Request[v1.IndexSchemaRequest]) (*connect.Response[v1.IndexSchemaResponse], error) { + return nil, connect.NewError(connect.CodeUnimplemented, errors.New("yoko.v1.YokoService.IndexSchema is not implemented")) } -func (UnimplementedYokoServiceHandler) Search(context.Context, *connect.Request[v1.SearchRequest]) (*connect.Response[v1.SearchResponse], error) { - return nil, connect.NewError(connect.CodeUnimplemented, errors.New("wundergraph.cosmo.code_mode.yoko.v1.YokoService.Search is not implemented")) +func (UnimplementedYokoServiceHandler) GenerateQuery(context.Context, *connect.Request[v1.GenerateQueryRequest]) (*connect.Response[v1.GenerateQueryResponse], error) { + return nil, connect.NewError(connect.CodeUnimplemented, errors.New("yoko.v1.YokoService.GenerateQuery is not implemented")) } diff --git a/router/go.mod b/router/go.mod index 499f3b8aad..67cd1da731 100644 --- a/router/go.mod +++ b/router/go.mod @@ -3,11 +3,11 @@ module github.com/wundergraph/cosmo/router go 1.25.0 require ( - connectrpc.com/connect v1.16.2 + connectrpc.com/connect v1.19.1 github.com/andybalholm/brotli v1.1.0 // indirect github.com/buger/jsonparser v1.1.2 github.com/cespare/xxhash/v2 v2.3.0 - github.com/cloudflare/backoff v0.0.0-20161212185259-647f3cdfc87a + github.com/cloudflare/backoff v0.0.0-20240920015135-e46b80a3a7d0 github.com/dustin/go-humanize v1.0.1 github.com/go-chi/chi/v5 v5.2.2 github.com/go-redis/redis_rate/v10 v10.0.1 @@ -87,24 +87,27 @@ require ( go.uber.org/goleak v1.3.0 go.uber.org/ratelimit v0.3.1 golang.org/x/exp v0.0.0-20250813145105-42675adae3e6 - 
golang.org/x/net v0.48.0 - golang.org/x/text v0.32.0 + golang.org/x/net v0.49.0 + golang.org/x/text v0.33.0 golang.org/x/time v0.9.0 ) -require github.com/tetratelabs/wazero v1.9.0 // indirect +require ( + github.com/tetratelabs/wazero v1.9.0 // indirect + gotest.tools/v3 v3.5.1 // indirect +) require ( github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 // indirect github.com/benbjohnson/clock v1.3.0 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect - github.com/cilium/ebpf v0.9.1 // indirect + github.com/cilium/ebpf v0.16.0 // indirect github.com/coder/websocket v1.8.14 // indirect github.com/containerd/cgroups/v3 v3.0.2 // indirect github.com/containerd/stargz-snapshotter/estargz v0.16.3 // indirect github.com/coreos/go-systemd/v22 v22.5.0 // indirect - github.com/davecgh/go-spew v1.1.1 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/docker/cli v29.2.0+incompatible // indirect github.com/docker/distribution v2.8.3+incompatible // indirect @@ -114,7 +117,6 @@ require ( github.com/fastschema/qjs v0.0.6 github.com/fatih/color v1.18.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect - github.com/frankban/quicktest v1.14.6 // indirect github.com/go-ini/ini v1.67.0 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect @@ -144,11 +146,11 @@ require ( github.com/oklog/run v1.0.0 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.1 // indirect - github.com/opencontainers/runtime-spec v1.1.0 // indirect + github.com/opencontainers/runtime-spec v1.2.0 // indirect github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 // indirect github.com/phf/go-queue v0.0.0-20170504031614-9abe38d0371d // indirect github.com/pierrec/lz4/v4 v4.1.21 // indirect - 
github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect github.com/prometheus/client_model v0.6.1 // indirect github.com/prometheus/common v0.55.0 // indirect @@ -173,7 +175,7 @@ require ( go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.33.0 // indirect go.opentelemetry.io/proto/otlp v1.4.0 // indirect go.uber.org/multierr v1.11.0 // indirect - golang.org/x/crypto v0.46.0 // indirect + golang.org/x/crypto v0.47.0 // indirect golang.org/x/oauth2 v0.34.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 // indirect diff --git a/router/go.sum b/router/go.sum index 7e82879a7d..d7db01f657 100644 --- a/router/go.sum +++ b/router/go.sum @@ -1,5 +1,5 @@ -connectrpc.com/connect v1.16.2 h1:ybd6y+ls7GOlb7Bh5C8+ghA6SvCBajHwxssO2CGFjqE= -connectrpc.com/connect v1.16.2/go.mod h1:n2kgwskMHXC+lVqb18wngEpF95ldBHXjZYJussz5FRc= +connectrpc.com/connect v1.19.1 h1:R5M57z05+90EfEvCY1b7hBxDVOUl45PrtXtAV2fOC14= +connectrpc.com/connect v1.19.1/go.mod h1:tN20fjdGlewnSFeZxLKb0xwIZ6ozc3OQs2hTXy4du9w= connectrpc.com/vanguard v0.3.0 h1:prUKFm8rYDwvpvnOSoqdUowPMK0tRA0pbSrQoMd6Zng= connectrpc.com/vanguard v0.3.0/go.mod h1:nxQ7+N6qhBiQczqGwdTw4oCqx1rDryIt20cEdECqToM= github.com/99designs/gqlgen v0.17.76 h1:YsJBcfACWmXWU2t1yCjoGdOmqcTfOFpjbLAE443fmYI= @@ -39,10 +39,10 @@ github.com/cep21/circuit/v4 v4.0.0 h1:g1AzMmRLuwCst0eccy1nGsD/CL2XKbDnaPUHVHDvVm github.com/cep21/circuit/v4 v4.0.0/go.mod h1:Bb1fHpuiRu+AIgKf7DTM1c5U94qTZtKouKcDwtZYCXk= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/cilium/ebpf v0.9.1 h1:64sn2K3UKw8NbP/blsixRpF3nXuyhz/VjRlRzvlBRu4= -github.com/cilium/ebpf 
v0.9.1/go.mod h1:+OhNOIXx/Fnu1IE8bJz2dzOA+VSfyTfdNUVdlQnxUFY= -github.com/cloudflare/backoff v0.0.0-20161212185259-647f3cdfc87a h1:8d1CEOF1xldesKds5tRG3tExBsMOgWYownMHNCsev54= -github.com/cloudflare/backoff v0.0.0-20161212185259-647f3cdfc87a/go.mod h1:rzgs2ZOiguV6/NpiDgADjRLPNyZlApIWxKpkT+X8SdY= +github.com/cilium/ebpf v0.16.0 h1:+BiEnHL6Z7lXnlGUsXQPPAE7+kenAd4ES8MQ5min0Ok= +github.com/cilium/ebpf v0.16.0/go.mod h1:L7u2Blt2jMM/vLAVgjxluxtBKlz3/GWjB0dMOEngfwE= +github.com/cloudflare/backoff v0.0.0-20240920015135-e46b80a3a7d0 h1:pRcxfaAlK0vR6nOeQs7eAEvjJzdGXl8+KaBlcvpQTyQ= +github.com/cloudflare/backoff v0.0.0-20240920015135-e46b80a3a7d0/go.mod h1:rzgs2ZOiguV6/NpiDgADjRLPNyZlApIWxKpkT+X8SdY= github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g= github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= github.com/containerd/cgroups/v3 v3.0.2 h1:f5WFqIVSgo5IZmtTT3qVBo6TzI1ON6sycSBKkymb9L0= @@ -51,10 +51,10 @@ github.com/containerd/stargz-snapshotter/estargz v0.16.3 h1:7evrXtoh1mSbGj/pfRcc github.com/containerd/stargz-snapshotter/estargz v0.16.3/go.mod h1:uyr4BfYfOj3G9WBVE8cOlQmXAbPN9VEQpBBeJIuOipU= github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= -github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgraph-io/ristretto/v2 v2.4.0 
h1:I/w09yLjhdcVD2QV192UJcq8dPBaAJb9pOuMyNy0XlU= github.com/dgraph-io/ristretto/v2 v2.4.0/go.mod h1:0KsrXtXvnv0EqnzyowllbVJB8yBonswa2lTCK2gGo9E= github.com/dgryski/go-farm v0.0.0-20240924180020-3414d57e47da h1:aIftn67I1fkbMa512G+w+Pxci9hJPB8oMnkcP3iZF38= @@ -84,8 +84,6 @@ github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= -github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= -github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/go-chi/chi/v5 v5.2.2 h1:CMwsvRVTbXVytCk1Wd72Zy1LAsAh9GxMmSNWLHCG618= github.com/go-chi/chi/v5 v5.2.2/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops= github.com/go-ini/ini v1.67.0 h1:z6ZrTEZqSWOTyH2FlglNbNgARyHG8oLW9gMELqKr06A= @@ -97,6 +95,8 @@ github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= +github.com/go-quicktest/qt v1.101.0 h1:O1K29Txy5P2OK0dGo59b7b0LR6wKfIhttaAhHUyn7eI= +github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZsDC26tQJow= github.com/go-redis/redis_rate/v10 v10.0.1 h1:calPxi7tVlxojKunJwQ72kwfozdy25RjA0bCj1h0MUo= github.com/go-redis/redis_rate/v10 v10.0.1/go.mod h1:EMiuO9+cjRkR7UvdvwMO7vbgqJkltQHtwbdIQvaBKIU= github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= @@ -169,6 +169,10 @@ github.com/jhump/protoreflect v1.17.0 h1:qOEr613fac2lOuTgWN4tPAtLL7fUSbuJL5X5Xum github.com/jhump/protoreflect v1.17.0/go.mod 
h1:h9+vUUL38jiBzck8ck+6G/aeMX8Z4QUY/NiJPwPNi+8= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA= +github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w= +github.com/jsimonetti/rtnetlink/v2 v2.0.1 h1:xda7qaHDSVOsADNouv7ukSuicKZO7GgVUCXxpaIEIlM= +github.com/jsimonetti/rtnetlink/v2 v2.0.1/go.mod h1:7MoNYNbb3UaDHtF8udiJo/RH6VsTKP1pqKLUTVCvToE= github.com/kingledion/go-tools v0.6.0 h1:y8C/4mWoHgLkO45dB+Y/j0o4Y4WUB5lDTAcMPMtFpTg= github.com/kingledion/go-tools v0.6.0/go.mod h1:qcDJQxBui/H/hterGb90GMlLs9Yi7QrwaJL8OGdbsms= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= @@ -197,6 +201,10 @@ github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Ky github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g= +github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw= +github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U= +github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA= github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34= github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM= github.com/minio/minio-go/v7 v7.0.74 h1:fTo/XlPBTSpo3BAMshlwKL5RspXRv9us5UeHEGYCFe0= @@ -221,20 +229,20 @@ github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8 github.com/opencontainers/go-digest v1.0.0/go.mod 
h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M= -github.com/opencontainers/runtime-spec v1.1.0 h1:HHUyrt9mwHUjtasSbXSMvs4cyFxh+Bll4AjJ9odEGpg= -github.com/opencontainers/runtime-spec v1.1.0/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0= +github.com/opencontainers/runtime-spec v1.2.0 h1:z97+pHb3uELt/yiAWD691HNHQIF07bE7dzrbT927iTk= +github.com/opencontainers/runtime-spec v1.2.0/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0= github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 h1:onHthvaw9LFnH4t2DcNVpwGmV9E1BkGknEliJkfwQj0= github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58/go.mod h1:DXv8WO4yhMYhSNPKjeNKa5WY9YCIEBRbNzFFPJbWO6Y= github.com/phf/go-queue v0.0.0-20170504031614-9abe38d0371d h1:U+PMnTlV2tu7RuMK5etusZG3Cf+rpow5hqQByeCzJ2g= github.com/phf/go-queue v0.0.0-20170504031614-9abe38d0371d/go.mod h1:lXfE4PvvTW5xOjO6Mba8zDPyw8M93B6AQ7frTGnMlA8= github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ= github.com/pierrec/lz4/v4 v4.1.21/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= -github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod 
h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/posthog/posthog-go v1.5.5 h1:2o3j7IrHbTIfxRtj4MPaXKeimuTYg49onNzNBZbwksM= github.com/posthog/posthog-go v1.5.5/go.mod h1:3RqUmSnPuwmeVj/GYrS75wNGqcAKdpODiwc83xZWgdE= github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw= @@ -256,7 +264,6 @@ github.com/r3labs/sse/v2 v2.8.1/go.mod h1:Igau6Whc+F17QUgML1fYe1VPZzTV6EMCnYktEm github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM= github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= -github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= @@ -395,8 +402,8 @@ go.withmatt.com/connect-brotli v0.4.0 h1:7ObWkYmEbUXK3EKglD0Lgj0BBnnD3jNdAxeDRct go.withmatt.com/connect-brotli v0.4.0/go.mod h1:c2eELz56za+/Mxh1yJrlglZ4VM9krpOCPqS2Vxf8NVk= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU= -golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0= +golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= +golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= golang.org/x/exp v0.0.0-20250813145105-42675adae3e6 h1:SbTAbRFnd5kjQXbczszQ0hdk3ctwYf3qBNH9jIsGclE= golang.org/x/exp v0.0.0-20250813145105-42675adae3e6/go.mod 
h1:4QTo5u+SEIbbKW1RacMZq1YEfOBqeXa19JeshGi+zc4= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= @@ -405,8 +412,8 @@ golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20191116160921-f9c825593386/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU= -golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY= +golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= +golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw= golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -431,8 +438,8 @@ golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= -golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= +golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= +golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY= golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools 
v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= @@ -466,6 +473,6 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gotest.tools/v3 v3.0.3 h1:4AuOwCGf4lLR9u3YOe2awrHygurzhO/HeQ6laiA6Sx0= -gotest.tools/v3 v3.0.3/go.mod h1:Z7Lb0S5l+klDB31fvDQX8ss/FlKDxtlFlw3Oa8Ymbl8= +gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU= +gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= diff --git a/router/internal/codemode/sandbox/sandbox_test.go b/router/internal/codemode/sandbox/sandbox_test.go index 48e8e8692a..aef9028bc8 100644 --- a/router/internal/codemode/sandbox/sandbox_test.go +++ b/router/internal/codemode/sandbox/sandbox_test.go @@ -116,6 +116,65 @@ func TestExecuteHappyPathToolCall(t *testing.T) { }, ExecuteResult{OK: got.OK, Result: got.Result, HostCalls: got.HostCalls}) } +func TestExecuteUsesDocumentNameWhenInvokingGraphQL(t *testing.T) { + // Regression: registered Name (camelCase, exposed as tools.) and + // the operation name baked into Body can differ. The host bridge must + // send the document's actual name as `operationName` so the router can + // match the operation definition; otherwise /graphql returns + // "operation with name 'X' not found". 
+ var gotBody map[string]any + client := clientFunc(func(r *http.Request) (*http.Response, error) { + require.NoError(t, json.NewDecoder(r.Body).Decode(&gotBody)) + return jsonResponse(http.StatusOK, `{"data":{"order":{"id":"o1"}}}`), nil + }) + + s := newTestSandbox(t, "http://router/graphql", lookup{ + "getOrder": { + Name: "getOrder", + Body: "query GetOrder($id: ID!) { order(id: $id) { id } }", + Kind: storage.OperationKindQuery, + DocumentName: "GetOrder", + }, + }, func(cfg *Config) { cfg.HTTPClient = client }) + + got := execute(t, s, ExecuteRequest{ + SessionID: "s1", + ToolNames: []string{"getOrder"}, + WrappedJS: `async () => await tools.getOrder({ id: "o1" })`, + }) + + assert.Equal(t, "GetOrder", gotBody["operationName"]) + assert.Equal(t, ExecuteResult{ + OK: true, + Result: raw(`{"data":{"order":{"id":"o1"}}}`), + HostCalls: 1, + }, ExecuteResult{OK: got.OK, Result: got.Result, HostCalls: got.HostCalls}) +} + +func TestExecuteFallsBackToNameWhenDocumentNameEmpty(t *testing.T) { + // Sessions written before the DocumentName field existed have an empty + // DocumentName. The bridge must fall back to op.Name so legacy entries + // keep working until they age out. 
+ var gotBody map[string]any + client := clientFunc(func(r *http.Request) (*http.Response, error) { + require.NoError(t, json.NewDecoder(r.Body).Decode(&gotBody)) + return jsonResponse(http.StatusOK, `{"data":null}`), nil + }) + + s := newTestSandbox(t, "http://router/graphql", lookup{ + "getOrder": {Name: "getOrder", Body: "query getOrder { order { id } }", Kind: storage.OperationKindQuery}, + }, func(cfg *Config) { cfg.HTTPClient = client }) + + got := execute(t, s, ExecuteRequest{ + SessionID: "s1", + ToolNames: []string{"getOrder"}, + WrappedJS: `async () => await tools.getOrder()`, + }) + + assert.Equal(t, "getOrder", gotBody["operationName"]) + require.True(t, got.OK) +} + func TestExecuteGraphQLErrorsResolveVerbatimAndRecordSpan(t *testing.T) { client := clientFunc(func(r *http.Request) (*http.Response, error) { return jsonResponse(http.StatusOK, `{"data":null,"errors":[{"message":"x"}]}`), nil diff --git a/router/internal/codemode/server/observability_handler_test.go b/router/internal/codemode/server/observability_handler_test.go index 8a4621326e..a73580b87c 100644 --- a/router/internal/codemode/server/observability_handler_test.go +++ b/router/internal/codemode/server/observability_handler_test.go @@ -24,10 +24,10 @@ import ( func TestHandleSearchRecordsObservability(t *testing.T) { traces, meterProvider, reader := newHandlerTelemetry() searcher := newFakeYoko() - searcher.responses <- &yokov1.SearchResponse{Operations: []*yokov1.GeneratedOperation{{ - Name: "getOrders", - Body: "query GetOrders { orders { id } }", - Kind: yokov1.OperationKind_OPERATION_KIND_QUERY, + searcher.responses <- &yokov1.Resolution{Queries: []*yokov1.ResolvedQuery{{ + OperationName: "getOrders", + Document: "query GetOrders { orders { id } }", + OperationType: "query", }}} store := newSearchTestStorage(t) srv, err := New(Config{ diff --git a/router/internal/codemode/yoko/client.go b/router/internal/codemode/yoko/client.go index d89b09f7bd..7116718d3e 100644 --- 
a/router/internal/codemode/yoko/client.go +++ b/router/internal/codemode/yoko/client.go @@ -51,15 +51,20 @@ func New(httpClient *http.Client, baseURL string, logger *zap.Logger, opts ...Op return client } -func (c *Client) Search(ctx context.Context, sessionID string, prompts []string) (*yokov1.SearchResponse, error) { +// Search resolves prompts against the indexed schema by fanning out one +// GenerateQuery RPC per prompt. The per-prompt Resolutions are merged into a +// single aggregated Resolution. If any RPC returns NotFound (yoko evicted the +// schema), the cached schema_id is invalidated and the entire batch is retried +// once. +func (c *Client) Search(ctx context.Context, prompts []string) (*yokov1.Resolution, error) { schemaID, err := c.ensureSchemaID(ctx) if err != nil { return nil, err } - resp, err := c.search(ctx, schemaID, sessionID, prompts) + resolution, err := c.generateAll(ctx, schemaID, prompts) if err == nil { - return resp, nil + return resolution, nil } if connect.CodeOf(err) != connect.CodeNotFound { return nil, err @@ -72,12 +77,12 @@ func (c *Client) Search(ctx context.Context, sessionID string, prompts []string) return nil, err } - resp, err = c.search(ctx, schemaID, sessionID, prompts) + resolution, err = c.generateAll(ctx, schemaID, prompts) if err != nil { c.invalidateSchemaID(schemaID) return nil, err } - return resp, nil + return resolution, nil } func (c *Client) SetSchema(sdl string) { @@ -87,6 +92,20 @@ func (c *Client) SetSchema(sdl string) { c.schemaID = "" } +// EnsureIndexed sends an IndexSchema RPC for the currently-stored SDL and +// caches the resulting schema_id. It is safe to call eagerly (e.g. from a +// background goroutine after SetSchema) so the first user-facing Search +// doesn't pay the IndexSchema round-trip latency. Concurrent callers +// coalesce on the SDL via the underlying single-flight; if an SDL is already +// indexed, the call is a no-op. With an empty SDL the call is a no-op. 
+func (c *Client) EnsureIndexed(ctx context.Context) error { + if c.Schema() == "" { + return nil + } + _, err := c.ensureSchemaID(ctx) + return err +} + func (c *Client) Schema() string { c.schemaMu.RLock() defer c.schemaMu.RUnlock() @@ -106,8 +125,8 @@ func (c *Client) ensureSchemaID(ctx context.Context) (string, error) { return currentSchemaID, nil } - resp, err := c.serviceClient.Index(ctx, connect.NewRequest(&yokov1.IndexRequest{ - SchemaSdl: sdl, + resp, err := c.serviceClient.IndexSchema(ctx, connect.NewRequest(&yokov1.IndexSchemaRequest{ + Sdl: sdl, })) if err != nil { return "", err @@ -123,16 +142,27 @@ func (c *Client) ensureSchemaID(ctx context.Context) (string, error) { return value.(string), nil } -func (c *Client) search(ctx context.Context, schemaID string, sessionID string, prompts []string) (*yokov1.SearchResponse, error) { - resp, err := c.serviceClient.Search(ctx, connect.NewRequest(&yokov1.SearchRequest{ - Prompts: prompts, - SchemaId: schemaID, - SessionId: sessionID, - })) - if err != nil { - return nil, err +func (c *Client) generateAll(ctx context.Context, schemaID string, prompts []string) (*yokov1.Resolution, error) { + aggregated := &yokov1.Resolution{} + for _, prompt := range prompts { + resp, err := c.serviceClient.GenerateQuery(ctx, connect.NewRequest(&yokov1.GenerateQueryRequest{ + SchemaId: schemaID, + Prompt: prompt, + })) + if err != nil { + return nil, err + } + r := resp.Msg.GetResolution() + if r == nil { + continue + } + aggregated.Queries = append(aggregated.Queries, r.GetQueries()...) + aggregated.Unsatisfied = append(aggregated.Unsatisfied, r.GetUnsatisfied()...) 
+ if r.GetTruncated() { + aggregated.Truncated = true + } } - return resp.Msg, nil + return aggregated, nil } func (c *Client) schemaState() (string, string) { diff --git a/router/internal/codemode/yoko/client_test.go b/router/internal/codemode/yoko/client_test.go index 136e5193c8..f6f0b1e486 100644 --- a/router/internal/codemode/yoko/client_test.go +++ b/router/internal/codemode/yoko/client_test.go @@ -16,14 +16,14 @@ import ( type fakeYokoServiceClient struct { mu sync.Mutex - indexRequests []*yokov1.IndexRequest - searchRequests []*yokov1.SearchRequest + indexRequests []*yokov1.IndexSchemaRequest + generateRequests []*yokov1.GenerateQueryRequest - indexFunc func(context.Context, *connect.Request[yokov1.IndexRequest]) (*connect.Response[yokov1.IndexResponse], error) - searchFunc func(context.Context, *connect.Request[yokov1.SearchRequest]) (*connect.Response[yokov1.SearchResponse], error) + indexFunc func(context.Context, *connect.Request[yokov1.IndexSchemaRequest]) (*connect.Response[yokov1.IndexSchemaResponse], error) + generateFunc func(context.Context, *connect.Request[yokov1.GenerateQueryRequest]) (*connect.Response[yokov1.GenerateQueryResponse], error) } -func (f *fakeYokoServiceClient) Index(ctx context.Context, req *connect.Request[yokov1.IndexRequest]) (*connect.Response[yokov1.IndexResponse], error) { +func (f *fakeYokoServiceClient) IndexSchema(ctx context.Context, req *connect.Request[yokov1.IndexSchemaRequest]) (*connect.Response[yokov1.IndexSchemaResponse], error) { f.mu.Lock() f.indexRequests = append(f.indexRequests, req.Msg) indexFunc := f.indexFunc @@ -32,31 +32,31 @@ func (f *fakeYokoServiceClient) Index(ctx context.Context, req *connect.Request[ if indexFunc != nil { return indexFunc(ctx, req) } - return connect.NewResponse(&yokov1.IndexResponse{SchemaId: "schema-1"}), nil + return connect.NewResponse(&yokov1.IndexSchemaResponse{SchemaId: "schema-1"}), nil } -func (f *fakeYokoServiceClient) Search(ctx context.Context, req 
*connect.Request[yokov1.SearchRequest]) (*connect.Response[yokov1.SearchResponse], error) { +func (f *fakeYokoServiceClient) GenerateQuery(ctx context.Context, req *connect.Request[yokov1.GenerateQueryRequest]) (*connect.Response[yokov1.GenerateQueryResponse], error) { f.mu.Lock() - f.searchRequests = append(f.searchRequests, req.Msg) - searchFunc := f.searchFunc + f.generateRequests = append(f.generateRequests, req.Msg) + generateFunc := f.generateFunc f.mu.Unlock() - if searchFunc != nil { - return searchFunc(ctx, req) + if generateFunc != nil { + return generateFunc(ctx, req) } - return connect.NewResponse(searchResponse("op")), nil + return connect.NewResponse(generateResponse(req.Msg.GetPrompt())), nil } -func (f *fakeYokoServiceClient) indexRequestMessages() []*yokov1.IndexRequest { +func (f *fakeYokoServiceClient) indexRequestMessages() []*yokov1.IndexSchemaRequest { f.mu.Lock() defer f.mu.Unlock() - return append([]*yokov1.IndexRequest(nil), f.indexRequests...) + return append([]*yokov1.IndexSchemaRequest(nil), f.indexRequests...) } -func (f *fakeYokoServiceClient) searchRequestMessages() []*yokov1.SearchRequest { +func (f *fakeYokoServiceClient) generateRequestMessages() []*yokov1.GenerateQueryRequest { f.mu.Lock() defer f.mu.Unlock() - return append([]*yokov1.SearchRequest(nil), f.searchRequests...) + return append([]*yokov1.GenerateQueryRequest(nil), f.generateRequests...) 
} func newTestClient(fake *fakeYokoServiceClient) *Client { @@ -65,14 +65,17 @@ func newTestClient(fake *fakeYokoServiceClient) *Client { return client } -func searchResponse(name string) *yokov1.SearchResponse { - return &yokov1.SearchResponse{ - Operations: []*yokov1.GeneratedOperation{ - { - Name: name, - Body: "query " + name + " { product { id } }", - Kind: yokov1.OperationKind_OPERATION_KIND_QUERY, - Description: "Fetch product", +func generateResponse(prompt string) *yokov1.GenerateQueryResponse { + return &yokov1.GenerateQueryResponse{ + Resolution: &yokov1.Resolution{ + Queries: []*yokov1.ResolvedQuery{ + { + Description: "Fetch product for prompt: " + prompt, + Document: "query GetProduct { product { id } }", + OperationName: "GetProduct", + OperationType: "query", + VariablesSchema: `{"type":"object","properties":{}}`, + }, }, }, } @@ -82,225 +85,223 @@ func connectError(code connect.Code, message string) error { return connect.NewError(code, errors.New(message)) } -func TestSearchFirstCallIndexesSchemaThenSearchesWithReturnedID(t *testing.T) { +func TestSearchFirstCallIndexesSchemaThenGeneratesPerPrompt(t *testing.T) { fake := &fakeYokoServiceClient{ - indexFunc: func(context.Context, *connect.Request[yokov1.IndexRequest]) (*connect.Response[yokov1.IndexResponse], error) { - return connect.NewResponse(&yokov1.IndexResponse{SchemaId: "schema-from-yoko"}), nil - }, - searchFunc: func(context.Context, *connect.Request[yokov1.SearchRequest]) (*connect.Response[yokov1.SearchResponse], error) { - return connect.NewResponse(searchResponse("fromSearch")), nil + indexFunc: func(context.Context, *connect.Request[yokov1.IndexSchemaRequest]) (*connect.Response[yokov1.IndexSchemaResponse], error) { + return connect.NewResponse(&yokov1.IndexSchemaResponse{SchemaId: "schema-from-yoko"}), nil }, } client := newTestClient(fake) - actual, err := client.Search(context.Background(), "session-1", []string{"find products"}) + actual, err := 
client.Search(context.Background(), []string{"find products", "find more products"}) require.NoError(t, err) - require.Equal(t, searchResponse("fromSearch"), actual) - require.Equal(t, []*yokov1.IndexRequest{ - {SchemaSdl: "type Query { product: Product }"}, - }, fake.indexRequestMessages()) - require.Equal(t, []*yokov1.SearchRequest{ - { - Prompts: []string{"find products"}, - SchemaId: "schema-from-yoko", - SessionId: "session-1", + require.Equal(t, &yokov1.Resolution{ + Queries: []*yokov1.ResolvedQuery{ + generateResponse("find products").GetResolution().GetQueries()[0], + generateResponse("find more products").GetResolution().GetQueries()[0], }, - }, fake.searchRequestMessages()) + }, actual) + require.Equal(t, []*yokov1.IndexSchemaRequest{ + {Sdl: "type Query { product: Product }"}, + }, fake.indexRequestMessages()) + require.Equal(t, []*yokov1.GenerateQueryRequest{ + {SchemaId: "schema-from-yoko", Prompt: "find products"}, + {SchemaId: "schema-from-yoko", Prompt: "find more products"}, + }, fake.generateRequestMessages()) } func TestSearchSubsequentCallUsesCachedSchemaID(t *testing.T) { fake := &fakeYokoServiceClient{} client := newTestClient(fake) - first, firstErr := client.Search(context.Background(), "session-1", []string{"first"}) - second, secondErr := client.Search(context.Background(), "session-2", []string{"second"}) + first, firstErr := client.Search(context.Background(), []string{"first"}) + second, secondErr := client.Search(context.Background(), []string{"second"}) require.NoError(t, firstErr) require.NoError(t, secondErr) - require.Equal(t, searchResponse("op"), first) - require.Equal(t, searchResponse("op"), second) - require.Equal(t, []*yokov1.IndexRequest{ - {SchemaSdl: "type Query { product: Product }"}, + require.Equal(t, generateResponse("first").GetResolution(), first) + require.Equal(t, generateResponse("second").GetResolution(), second) + require.Equal(t, []*yokov1.IndexSchemaRequest{ + {Sdl: "type Query { product: Product }"}, }, 
fake.indexRequestMessages()) - require.Equal(t, []*yokov1.SearchRequest{ - { - Prompts: []string{"first"}, - SchemaId: "schema-1", - SessionId: "session-1", - }, - { - Prompts: []string{"second"}, - SchemaId: "schema-1", - SessionId: "session-2", + require.Equal(t, []*yokov1.GenerateQueryRequest{ + {SchemaId: "schema-1", Prompt: "first"}, + {SchemaId: "schema-1", Prompt: "second"}, + }, fake.generateRequestMessages()) +} + +func TestSearchAggregatesResolutionAcrossPrompts(t *testing.T) { + calls := 0 + fake := &fakeYokoServiceClient{ + generateFunc: func(_ context.Context, req *connect.Request[yokov1.GenerateQueryRequest]) (*connect.Response[yokov1.GenerateQueryResponse], error) { + calls++ + switch calls { + case 1: + return connect.NewResponse(&yokov1.GenerateQueryResponse{ + Resolution: &yokov1.Resolution{ + Queries: []*yokov1.ResolvedQuery{{Document: "q1"}}, + }, + }), nil + case 2: + return connect.NewResponse(&yokov1.GenerateQueryResponse{ + Resolution: &yokov1.Resolution{ + Unsatisfied: []*yokov1.Unsatisfied{{Reason: "no field for that filter"}}, + Truncated: true, + }, + }), nil + } + return connect.NewResponse(&yokov1.GenerateQueryResponse{}), nil }, - }, fake.searchRequestMessages()) + } + client := newTestClient(fake) + + actual, err := client.Search(context.Background(), []string{"a", "b"}) + + require.NoError(t, err) + require.Equal(t, &yokov1.Resolution{ + Queries: []*yokov1.ResolvedQuery{{Document: "q1"}}, + Unsatisfied: []*yokov1.Unsatisfied{{Reason: "no field for that filter"}}, + Truncated: true, + }, actual) } func TestSearchReindexesAndRetriesOnceAfterNotFound(t *testing.T) { - var searchCount int + var generateCount int fake := &fakeYokoServiceClient{} indexIDs := []string{"schema-initial", "schema-reindexed"} - fake.indexFunc = func(context.Context, *connect.Request[yokov1.IndexRequest]) (*connect.Response[yokov1.IndexResponse], error) { + fake.indexFunc = func(context.Context, *connect.Request[yokov1.IndexSchemaRequest]) 
(*connect.Response[yokov1.IndexSchemaResponse], error) { id := indexIDs[len(fake.indexRequestMessages())-1] - return connect.NewResponse(&yokov1.IndexResponse{SchemaId: id}), nil + return connect.NewResponse(&yokov1.IndexSchemaResponse{SchemaId: id}), nil } - fake.searchFunc = func(context.Context, *connect.Request[yokov1.SearchRequest]) (*connect.Response[yokov1.SearchResponse], error) { - searchCount++ - if searchCount == 1 { + fake.generateFunc = func(_ context.Context, req *connect.Request[yokov1.GenerateQueryRequest]) (*connect.Response[yokov1.GenerateQueryResponse], error) { + generateCount++ + if generateCount == 1 { return nil, connectError(connect.CodeNotFound, "schema evicted") } - return connect.NewResponse(searchResponse("retried")), nil + return connect.NewResponse(generateResponse(req.Msg.GetPrompt())), nil } client := newTestClient(fake) - actual, err := client.Search(context.Background(), "session-1", []string{"find products"}) + actual, err := client.Search(context.Background(), []string{"find products"}) require.NoError(t, err) - require.Equal(t, searchResponse("retried"), actual) - require.Equal(t, []*yokov1.IndexRequest{ - {SchemaSdl: "type Query { product: Product }"}, - {SchemaSdl: "type Query { product: Product }"}, + require.Equal(t, generateResponse("find products").GetResolution(), actual) + require.Equal(t, []*yokov1.IndexSchemaRequest{ + {Sdl: "type Query { product: Product }"}, + {Sdl: "type Query { product: Product }"}, }, fake.indexRequestMessages()) - require.Equal(t, []*yokov1.SearchRequest{ - { - Prompts: []string{"find products"}, - SchemaId: "schema-initial", - SessionId: "session-1", - }, - { - Prompts: []string{"find products"}, - SchemaId: "schema-reindexed", - SessionId: "session-1", - }, - }, fake.searchRequestMessages()) + require.Equal(t, []*yokov1.GenerateQueryRequest{ + {SchemaId: "schema-initial", Prompt: "find products"}, + {SchemaId: "schema-reindexed", Prompt: "find products"}, + }, fake.generateRequestMessages()) } 
func TestSearchRetryFailureSurfacesErrorAndLeavesCacheEmpty(t *testing.T) { retryErr := connectError(connect.CodeUnavailable, "retry transport down") indexIDs := []string{"schema-initial", "schema-reindexed", "schema-after-failure"} fake := &fakeYokoServiceClient{} - fake.indexFunc = func(context.Context, *connect.Request[yokov1.IndexRequest]) (*connect.Response[yokov1.IndexResponse], error) { + fake.indexFunc = func(context.Context, *connect.Request[yokov1.IndexSchemaRequest]) (*connect.Response[yokov1.IndexSchemaResponse], error) { id := indexIDs[len(fake.indexRequestMessages())-1] - return connect.NewResponse(&yokov1.IndexResponse{SchemaId: id}), nil + return connect.NewResponse(&yokov1.IndexSchemaResponse{SchemaId: id}), nil } - searchErrors := []error{ + generateErrors := []error{ connectError(connect.CodeNotFound, "schema evicted"), retryErr, nil, } - fake.searchFunc = func(context.Context, *connect.Request[yokov1.SearchRequest]) (*connect.Response[yokov1.SearchResponse], error) { - err := searchErrors[len(fake.searchRequestMessages())-1] + fake.generateFunc = func(_ context.Context, req *connect.Request[yokov1.GenerateQueryRequest]) (*connect.Response[yokov1.GenerateQueryResponse], error) { + err := generateErrors[len(fake.generateRequestMessages())-1] if err != nil { return nil, err } - return connect.NewResponse(searchResponse("afterFailure")), nil + return connect.NewResponse(generateResponse(req.Msg.GetPrompt())), nil } client := newTestClient(fake) - actual, err := client.Search(context.Background(), "session-1", []string{"find products"}) + actual, err := client.Search(context.Background(), []string{"find products"}) require.Nil(t, actual) require.ErrorIs(t, err, retryErr) - actualAfterFailure, errAfterFailure := client.Search(context.Background(), "session-2", []string{"find products again"}) + actualAfterFailure, errAfterFailure := client.Search(context.Background(), []string{"find products again"}) require.NoError(t, errAfterFailure) - 
require.Equal(t, searchResponse("afterFailure"), actualAfterFailure) - require.Equal(t, []*yokov1.IndexRequest{ - {SchemaSdl: "type Query { product: Product }"}, - {SchemaSdl: "type Query { product: Product }"}, - {SchemaSdl: "type Query { product: Product }"}, + require.Equal(t, generateResponse("find products again").GetResolution(), actualAfterFailure) + require.Equal(t, []*yokov1.IndexSchemaRequest{ + {Sdl: "type Query { product: Product }"}, + {Sdl: "type Query { product: Product }"}, + {Sdl: "type Query { product: Product }"}, }, fake.indexRequestMessages()) - require.Equal(t, []*yokov1.SearchRequest{ - { - Prompts: []string{"find products"}, - SchemaId: "schema-initial", - SessionId: "session-1", - }, - { - Prompts: []string{"find products"}, - SchemaId: "schema-reindexed", - SessionId: "session-1", - }, - { - Prompts: []string{"find products again"}, - SchemaId: "schema-after-failure", - SessionId: "session-2", - }, - }, fake.searchRequestMessages()) + require.Equal(t, []*yokov1.GenerateQueryRequest{ + {SchemaId: "schema-initial", Prompt: "find products"}, + {SchemaId: "schema-reindexed", Prompt: "find products"}, + {SchemaId: "schema-after-failure", Prompt: "find products again"}, + }, fake.generateRequestMessages()) } func TestSearchRetryNotFoundSurfacesErrorAndLeavesCacheEmpty(t *testing.T) { retryErr := connectError(connect.CodeNotFound, "schema evicted again") indexIDs := []string{"schema-initial", "schema-reindexed", "schema-after-failure"} fake := &fakeYokoServiceClient{} - fake.indexFunc = func(context.Context, *connect.Request[yokov1.IndexRequest]) (*connect.Response[yokov1.IndexResponse], error) { + fake.indexFunc = func(context.Context, *connect.Request[yokov1.IndexSchemaRequest]) (*connect.Response[yokov1.IndexSchemaResponse], error) { id := indexIDs[len(fake.indexRequestMessages())-1] - return connect.NewResponse(&yokov1.IndexResponse{SchemaId: id}), nil + return connect.NewResponse(&yokov1.IndexSchemaResponse{SchemaId: id}), nil } - 
searchErrors := []error{ + generateErrors := []error{ connectError(connect.CodeNotFound, "schema evicted"), retryErr, nil, } - fake.searchFunc = func(context.Context, *connect.Request[yokov1.SearchRequest]) (*connect.Response[yokov1.SearchResponse], error) { - err := searchErrors[len(fake.searchRequestMessages())-1] + fake.generateFunc = func(_ context.Context, req *connect.Request[yokov1.GenerateQueryRequest]) (*connect.Response[yokov1.GenerateQueryResponse], error) { + err := generateErrors[len(fake.generateRequestMessages())-1] if err != nil { return nil, err } - return connect.NewResponse(searchResponse("afterFailure")), nil + return connect.NewResponse(generateResponse(req.Msg.GetPrompt())), nil } client := newTestClient(fake) - actual, err := client.Search(context.Background(), "session-1", []string{"find products"}) + actual, err := client.Search(context.Background(), []string{"find products"}) require.Nil(t, actual) require.ErrorIs(t, err, retryErr) - actualAfterFailure, errAfterFailure := client.Search(context.Background(), "session-2", []string{"find products again"}) + actualAfterFailure, errAfterFailure := client.Search(context.Background(), []string{"find products again"}) require.NoError(t, errAfterFailure) - require.Equal(t, searchResponse("afterFailure"), actualAfterFailure) - require.Equal(t, []*yokov1.IndexRequest{ - {SchemaSdl: "type Query { product: Product }"}, - {SchemaSdl: "type Query { product: Product }"}, - {SchemaSdl: "type Query { product: Product }"}, + require.Equal(t, generateResponse("find products again").GetResolution(), actualAfterFailure) + require.Equal(t, []*yokov1.IndexSchemaRequest{ + {Sdl: "type Query { product: Product }"}, + {Sdl: "type Query { product: Product }"}, + {Sdl: "type Query { product: Product }"}, }, fake.indexRequestMessages()) } func TestSetSchemaInvalidatesCachedIDAndNextSearchReindexes(t *testing.T) { indexIDs := []string{"schema-v1", "schema-v2"} fake := &fakeYokoServiceClient{} - fake.indexFunc = 
func(context.Context, *connect.Request[yokov1.IndexRequest]) (*connect.Response[yokov1.IndexResponse], error) { + fake.indexFunc = func(context.Context, *connect.Request[yokov1.IndexSchemaRequest]) (*connect.Response[yokov1.IndexSchemaResponse], error) { id := indexIDs[len(fake.indexRequestMessages())-1] - return connect.NewResponse(&yokov1.IndexResponse{SchemaId: id}), nil + return connect.NewResponse(&yokov1.IndexSchemaResponse{SchemaId: id}), nil } client := newTestClient(fake) - _, firstErr := client.Search(context.Background(), "session-1", []string{"first"}) + _, firstErr := client.Search(context.Background(), []string{"first"}) client.SetSchema("type Query { review: Review }") - _, secondErr := client.Search(context.Background(), "session-2", []string{"second"}) + _, secondErr := client.Search(context.Background(), []string{"second"}) require.NoError(t, firstErr) require.NoError(t, secondErr) require.Equal(t, "type Query { review: Review }", client.Schema()) - require.Equal(t, []*yokov1.IndexRequest{ - {SchemaSdl: "type Query { product: Product }"}, - {SchemaSdl: "type Query { review: Review }"}, + require.Equal(t, []*yokov1.IndexSchemaRequest{ + {Sdl: "type Query { product: Product }"}, + {Sdl: "type Query { review: Review }"}, }, fake.indexRequestMessages()) - require.Equal(t, []*yokov1.SearchRequest{ - { - Prompts: []string{"first"}, - SchemaId: "schema-v1", - SessionId: "session-1", - }, - { - Prompts: []string{"second"}, - SchemaId: "schema-v2", - SessionId: "session-2", - }, - }, fake.searchRequestMessages()) + require.Equal(t, []*yokov1.GenerateQueryRequest{ + {SchemaId: "schema-v1", Prompt: "first"}, + {SchemaId: "schema-v2", Prompt: "second"}, + }, fake.generateRequestMessages()) } func TestConcurrentFirstSearchIndexesOnce(t *testing.T) { @@ -308,28 +309,28 @@ func TestConcurrentFirstSearchIndexesOnce(t *testing.T) { releaseIndex := make(chan struct{}) var indexStartedOnce sync.Once fake := &fakeYokoServiceClient{ - indexFunc: func(context.Context, 
*connect.Request[yokov1.IndexRequest]) (*connect.Response[yokov1.IndexResponse], error) { + indexFunc: func(context.Context, *connect.Request[yokov1.IndexSchemaRequest]) (*connect.Response[yokov1.IndexSchemaResponse], error) { indexStartedOnce.Do(func() { close(indexStarted) }) <-releaseIndex - return connect.NewResponse(&yokov1.IndexResponse{SchemaId: "schema-shared"}), nil + return connect.NewResponse(&yokov1.IndexSchemaResponse{SchemaId: "schema-shared"}), nil }, } client := newTestClient(fake) var wg sync.WaitGroup wg.Add(2) - results := make([]*yokov1.SearchResponse, 2) + results := make([]*yokov1.Resolution, 2) errs := make([]error, 2) go func() { defer wg.Done() - results[0], errs[0] = client.Search(context.Background(), "session-1", []string{"first"}) + results[0], errs[0] = client.Search(context.Background(), []string{"first"}) }() <-indexStarted go func() { defer wg.Done() - results[1], errs[1] = client.Search(context.Background(), "session-2", []string{"second"}) + results[1], errs[1] = client.Search(context.Background(), []string{"second"}) }() time.Sleep(25 * time.Millisecond) close(releaseIndex) @@ -337,12 +338,12 @@ func TestConcurrentFirstSearchIndexesOnce(t *testing.T) { require.NoError(t, errs[0]) require.NoError(t, errs[1]) - require.Equal(t, searchResponse("op"), results[0]) - require.Equal(t, searchResponse("op"), results[1]) - require.Equal(t, []*yokov1.IndexRequest{ - {SchemaSdl: "type Query { product: Product }"}, + require.Equal(t, generateResponse("first").GetResolution(), results[0]) + require.Equal(t, generateResponse("second").GetResolution(), results[1]) + require.Equal(t, []*yokov1.IndexSchemaRequest{ + {Sdl: "type Query { product: Product }"}, }, fake.indexRequestMessages()) - assert.Equal(t, 2, len(fake.searchRequestMessages())) + assert.Equal(t, 2, len(fake.generateRequestMessages())) } func TestConcurrentFirstSearchIndexFailureReturnsErrorToBothAndLeavesCacheEmpty(t *testing.T) { @@ -351,7 +352,7 @@ func 
TestConcurrentFirstSearchIndexFailureReturnsErrorToBothAndLeavesCacheEmpty( releaseIndex := make(chan struct{}) var indexStartedOnce sync.Once fake := &fakeYokoServiceClient{ - indexFunc: func(context.Context, *connect.Request[yokov1.IndexRequest]) (*connect.Response[yokov1.IndexResponse], error) { + indexFunc: func(context.Context, *connect.Request[yokov1.IndexSchemaRequest]) (*connect.Response[yokov1.IndexSchemaResponse], error) { indexStartedOnce.Do(func() { close(indexStarted) }) @@ -363,16 +364,16 @@ func TestConcurrentFirstSearchIndexFailureReturnsErrorToBothAndLeavesCacheEmpty( var wg sync.WaitGroup wg.Add(2) - results := make([]*yokov1.SearchResponse, 2) + results := make([]*yokov1.Resolution, 2) errs := make([]error, 2) go func() { defer wg.Done() - results[0], errs[0] = client.Search(context.Background(), "session-1", []string{"first"}) + results[0], errs[0] = client.Search(context.Background(), []string{"first"}) }() <-indexStarted go func() { defer wg.Done() - results[1], errs[1] = client.Search(context.Background(), "session-2", []string{"second"}) + results[1], errs[1] = client.Search(context.Background(), []string{"second"}) }() time.Sleep(25 * time.Millisecond) close(releaseIndex) @@ -382,47 +383,73 @@ func TestConcurrentFirstSearchIndexFailureReturnsErrorToBothAndLeavesCacheEmpty( require.Nil(t, results[1]) require.ErrorIs(t, errs[0], indexErr) require.ErrorIs(t, errs[1], indexErr) - require.Equal(t, []*yokov1.IndexRequest{ - {SchemaSdl: "type Query { product: Product }"}, + require.Equal(t, []*yokov1.IndexSchemaRequest{ + {Sdl: "type Query { product: Product }"}, }, fake.indexRequestMessages()) - require.Equal(t, []*yokov1.SearchRequest(nil), fake.searchRequestMessages()) + require.Equal(t, []*yokov1.GenerateQueryRequest(nil), fake.generateRequestMessages()) - fake.indexFunc = func(context.Context, *connect.Request[yokov1.IndexRequest]) (*connect.Response[yokov1.IndexResponse], error) { - return connect.NewResponse(&yokov1.IndexResponse{SchemaId: 
"schema-after-error"}), nil + fake.indexFunc = func(context.Context, *connect.Request[yokov1.IndexSchemaRequest]) (*connect.Response[yokov1.IndexSchemaResponse], error) { + return connect.NewResponse(&yokov1.IndexSchemaResponse{SchemaId: "schema-after-error"}), nil } - actual, err := client.Search(context.Background(), "session-3", []string{"third"}) + actual, err := client.Search(context.Background(), []string{"third"}) require.NoError(t, err) - require.Equal(t, searchResponse("op"), actual) - require.Equal(t, []*yokov1.IndexRequest{ - {SchemaSdl: "type Query { product: Product }"}, - {SchemaSdl: "type Query { product: Product }"}, + require.Equal(t, generateResponse("third").GetResolution(), actual) + require.Equal(t, []*yokov1.IndexSchemaRequest{ + {Sdl: "type Query { product: Product }"}, + {Sdl: "type Query { product: Product }"}, }, fake.indexRequestMessages()) } func TestSearchBubblesUpArbitraryConnectErrors(t *testing.T) { - searchErr := connectError(connect.CodeUnavailable, "search unavailable") + generateErr := connectError(connect.CodeUnavailable, "generate unavailable") fake := &fakeYokoServiceClient{ - searchFunc: func(context.Context, *connect.Request[yokov1.SearchRequest]) (*connect.Response[yokov1.SearchResponse], error) { - return nil, searchErr + generateFunc: func(context.Context, *connect.Request[yokov1.GenerateQueryRequest]) (*connect.Response[yokov1.GenerateQueryResponse], error) { + return nil, generateErr }, } client := newTestClient(fake) - actual, err := client.Search(context.Background(), "session-1", []string{"find products"}) + actual, err := client.Search(context.Background(), []string{"find products"}) require.Nil(t, actual) - require.ErrorIs(t, err, searchErr) - require.Equal(t, []*yokov1.IndexRequest{ - {SchemaSdl: "type Query { product: Product }"}, + require.ErrorIs(t, err, generateErr) + require.Equal(t, []*yokov1.IndexSchemaRequest{ + {Sdl: "type Query { product: Product }"}, }, fake.indexRequestMessages()) - require.Equal(t, 
[]*yokov1.SearchRequest{ - { - Prompts: []string{"find products"}, - SchemaId: "schema-1", - SessionId: "session-1", + require.Equal(t, []*yokov1.GenerateQueryRequest{ + {SchemaId: "schema-1", Prompt: "find products"}, + }, fake.generateRequestMessages()) +} + +func TestEnsureIndexedSendsIndexSchemaAndCachesID(t *testing.T) { + fake := &fakeYokoServiceClient{ + indexFunc: func(context.Context, *connect.Request[yokov1.IndexSchemaRequest]) (*connect.Response[yokov1.IndexSchemaResponse], error) { + return connect.NewResponse(&yokov1.IndexSchemaResponse{SchemaId: "schema-warm"}), nil }, - }, fake.searchRequestMessages()) + } + client := newTestClient(fake) + + require.NoError(t, client.EnsureIndexed(context.Background())) + + // Cached schema_id is reused by the next Search — no second IndexSchema RPC. + _, err := client.Search(context.Background(), []string{"first"}) + require.NoError(t, err) + + require.Equal(t, []*yokov1.IndexSchemaRequest{ + {Sdl: "type Query { product: Product }"}, + }, fake.indexRequestMessages()) + require.Equal(t, []*yokov1.GenerateQueryRequest{ + {SchemaId: "schema-warm", Prompt: "first"}, + }, fake.generateRequestMessages()) +} + +func TestEnsureIndexedNoOpWhenSchemaUnset(t *testing.T) { + fake := &fakeYokoServiceClient{} + client := New(nil, "http://yoko.example", nil, WithServiceClient(fake)) + + require.NoError(t, client.EnsureIndexed(context.Background())) + require.Empty(t, fake.indexRequestMessages()) } func TestSchemaGetterReturnsCurrentSchema(t *testing.T) { diff --git a/router/internal/codemode/yoko/searcher.go b/router/internal/codemode/yoko/searcher.go index 611e5f6fe2..ea1d174542 100644 --- a/router/internal/codemode/yoko/searcher.go +++ b/router/internal/codemode/yoko/searcher.go @@ -7,9 +7,12 @@ import ( ) type Searcher interface { - Search(ctx context.Context, sessionID string, prompts []string) (*yokov1.SearchResponse, error) + Search(ctx context.Context, prompts []string) (*yokov1.Resolution, error) SetSchema(string) Schema() 
string + // EnsureIndexed proactively warms the schema_id cache so the first + // Search after a (re)load doesn't pay the IndexSchema round-trip. + EnsureIndexed(ctx context.Context) error } var _ Searcher = (*Client)(nil) diff --git a/router/pkg/codemode/varschema/varschema.go b/router/pkg/codemode/varschema/varschema.go new file mode 100644 index 0000000000..28ab68ab3d --- /dev/null +++ b/router/pkg/codemode/varschema/varschema.go @@ -0,0 +1,329 @@ +// Package varschema derives a JSON Schema describing the `$variables` object +// of a GraphQL operation, statically against a parsed schema document. +// +// The generator is shared between the router (which consumes the JSON Schema +// returned by yoko) and the yoko mock (which produces it). It walks the same +// AST shape that the TypeScript bundle renderer uses, but emits JSON Schema +// instead of TS types so the schema is portable across non-TS clients. +package varschema + +import ( + "encoding/json" + "fmt" + "slices" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" + "github.com/wundergraph/graphql-go-tools/v2/pkg/astparser" +) + +// ForOperation returns a JSON Schema (encoded as a JSON string) that describes +// the `$variables` object accepted by the given GraphQL operation body, +// resolving named types against schema. 
+func ForOperation(opBody string, schema *ast.Document) (string, error) { + if schema == nil { + return "", fmt.Errorf("variables JSON schema: schema is nil") + } + + opDoc, report := astparser.ParseGraphqlDocumentString(opBody) + if report.HasErrors() { + return "", fmt.Errorf("variables JSON schema: parse operation: %s", report.Error()) + } + + opRef, err := singleOperationRef(&opDoc) + if err != nil { + return "", fmt.Errorf("variables JSON schema: %w", err) + } + + r := renderer{schema: schema} + root, err := r.variablesSchema(&opDoc, opRef) + if err != nil { + return "", fmt.Errorf("variables JSON schema: %w", err) + } + + encoded, err := json.Marshal(root) + if err != nil { + return "", fmt.Errorf("variables JSON schema: encode: %w", err) + } + return string(encoded), nil +} + +func singleOperationRef(doc *ast.Document) (int, error) { + var refs []int + for _, node := range doc.RootNodes { + if node.Kind == ast.NodeKindOperationDefinition { + refs = append(refs, node.Ref) + } + } + if len(refs) == 0 { + return 0, fmt.Errorf("operation document contains no operation definition") + } + if len(refs) > 1 { + return 0, fmt.Errorf("operation document contains %d operation definitions", len(refs)) + } + return refs[0], nil +} + +type renderer struct { + schema *ast.Document +} + +// orderedSchema preserves field declaration order in JSON output. +type orderedSchema struct { + pairs []orderedSchemaEntry +} + +type orderedSchemaEntry struct { + key string + value any +} + +func (o *orderedSchema) set(key string, value any) { + o.pairs = append(o.pairs, orderedSchemaEntry{key: key, value: value}) +} + +func (o orderedSchema) MarshalJSON() ([]byte, error) { + buf := []byte{'{'} + for i, p := range o.pairs { + if i > 0 { + buf = append(buf, ',') + } + k, err := json.Marshal(p.key) + if err != nil { + return nil, err + } + v, err := json.Marshal(p.value) + if err != nil { + return nil, err + } + buf = append(buf, k...) + buf = append(buf, ':') + buf = append(buf, v...) 
+ } + buf = append(buf, '}') + return buf, nil +} + +func (r renderer) variablesSchema(opDoc *ast.Document, opRef int) (orderedSchema, error) { + op := opDoc.OperationDefinitions[opRef] + root := orderedSchema{} + root.set("type", "object") + + if !op.HasVariableDefinitions || len(op.VariableDefinitions.Refs) == 0 { + root.set("properties", orderedSchema{}) + return root, nil + } + + props := orderedSchema{} + required := make([]string, 0, len(op.VariableDefinitions.Refs)) + for _, varRef := range op.VariableDefinitions.Refs { + name := opDoc.VariableDefinitionNameString(varRef) + typeRef := opDoc.VariableDefinitionType(varRef) + isRequired := opDoc.Types[typeRef].TypeKind == ast.TypeKindNonNull + + s, err := r.opType(opDoc, typeRef) + if err != nil { + return orderedSchema{}, err + } + props.set(name, s) + if isRequired { + required = append(required, name) + } + } + root.set("properties", props) + if len(required) > 0 { + root.set("required", required) + } + return root, nil +} + +// opType walks types living in the operation document. The result is a +// nullable JSON Schema fragment unless the type is wrapped in NonNull. 
+func (r renderer) opType(opDoc *ast.Document, typeRef int) (orderedSchema, error) { + gqlType := opDoc.Types[typeRef] + switch gqlType.TypeKind { + case ast.TypeKindNonNull: + return r.opTypeNonNull(opDoc, gqlType.OfType) + case ast.TypeKindList: + inner, err := r.opType(opDoc, gqlType.OfType) + if err != nil { + return orderedSchema{}, err + } + s := orderedSchema{} + s.set("type", []string{"array", "null"}) + s.set("items", inner) + return s, nil + case ast.TypeKindNamed: + s, err := r.namedType(opDoc.TypeNameString(typeRef)) + if err != nil { + return orderedSchema{}, err + } + return makeNullable(s), nil + default: + return orderedSchema{}, fmt.Errorf("unsupported GraphQL input type kind %s", gqlType.TypeKind.String()) + } +} + +func (r renderer) opTypeNonNull(opDoc *ast.Document, typeRef int) (orderedSchema, error) { + gqlType := opDoc.Types[typeRef] + switch gqlType.TypeKind { + case ast.TypeKindNonNull: + return r.opTypeNonNull(opDoc, gqlType.OfType) + case ast.TypeKindList: + inner, err := r.opType(opDoc, gqlType.OfType) + if err != nil { + return orderedSchema{}, err + } + s := orderedSchema{} + s.set("type", "array") + s.set("items", inner) + return s, nil + case ast.TypeKindNamed: + return r.namedType(opDoc.TypeNameString(typeRef)) + default: + return orderedSchema{}, fmt.Errorf("unsupported GraphQL input type kind %s", gqlType.TypeKind.String()) + } +} + +func (r renderer) namedType(typeName string) (orderedSchema, error) { + s := orderedSchema{} + switch typeName { + case "ID", "String": + s.set("type", "string") + return s, nil + case "Int": + s.set("type", "integer") + return s, nil + case "Float": + s.set("type", "number") + return s, nil + case "Boolean": + s.set("type", "boolean") + return s, nil + } + + node, exists := r.schema.Index.FirstNonExtensionNodeByNameBytes([]byte(typeName)) + if !exists { + return orderedSchema{}, fmt.Errorf("missing schema type %q", typeName) + } + + switch node.Kind { + case ast.NodeKindEnumTypeDefinition: + 
s.set("type", "string") + s.set("enum", r.enumValues(node.Ref)) + return s, nil + case ast.NodeKindScalarTypeDefinition: + // Custom scalars: leave the type open. JSON Schema's empty schema {} + // matches anything; we instead emit type:any-of-known to keep clients + // from misvalidating. The simplest acceptable encoding is no `type`. + return s, nil + case ast.NodeKindInputObjectTypeDefinition: + return r.inputObject(node.Ref) + default: + return s, nil + } +} + +func (r renderer) enumValues(enumRef int) []string { + def := r.schema.EnumTypeDefinitions[enumRef] + values := make([]string, 0, len(def.EnumValuesDefinition.Refs)) + for _, valueRef := range def.EnumValuesDefinition.Refs { + values = append(values, r.schema.EnumValueDefinitionNameString(valueRef)) + } + return values +} + +func (r renderer) inputObject(inputObjectRef int) (orderedSchema, error) { + def := r.schema.InputObjectTypeDefinitions[inputObjectRef] + s := orderedSchema{} + s.set("type", "object") + + props := orderedSchema{} + required := make([]string, 0, len(def.InputFieldsDefinition.Refs)) + for _, fieldRef := range def.InputFieldsDefinition.Refs { + name := r.schema.InputValueDefinitionNameString(fieldRef) + typeRef := r.schema.InputValueDefinitionType(fieldRef) + isRequired := r.schema.Types[typeRef].TypeKind == ast.TypeKindNonNull + + field, err := r.schemaType(typeRef) + if err != nil { + return orderedSchema{}, err + } + props.set(name, field) + if isRequired { + required = append(required, name) + } + } + s.set("properties", props) + if len(required) > 0 { + s.set("required", required) + } + return s, nil +} + +// schemaType walks types in the schema document (input fields nested inside +// input objects). Mirrors opType but reads from r.schema. 
+func (r renderer) schemaType(typeRef int) (orderedSchema, error) { + gqlType := r.schema.Types[typeRef] + switch gqlType.TypeKind { + case ast.TypeKindNonNull: + return r.schemaTypeNonNull(gqlType.OfType) + case ast.TypeKindList: + inner, err := r.schemaType(gqlType.OfType) + if err != nil { + return orderedSchema{}, err + } + s := orderedSchema{} + s.set("type", []string{"array", "null"}) + s.set("items", inner) + return s, nil + case ast.TypeKindNamed: + s, err := r.namedType(r.schema.TypeNameString(typeRef)) + if err != nil { + return orderedSchema{}, err + } + return makeNullable(s), nil + default: + return orderedSchema{}, fmt.Errorf("unsupported GraphQL input type kind %s", gqlType.TypeKind.String()) + } +} + +func (r renderer) schemaTypeNonNull(typeRef int) (orderedSchema, error) { + gqlType := r.schema.Types[typeRef] + switch gqlType.TypeKind { + case ast.TypeKindNonNull: + return r.schemaTypeNonNull(gqlType.OfType) + case ast.TypeKindList: + inner, err := r.schemaType(gqlType.OfType) + if err != nil { + return orderedSchema{}, err + } + s := orderedSchema{} + s.set("type", "array") + s.set("items", inner) + return s, nil + case ast.TypeKindNamed: + return r.namedType(r.schema.TypeNameString(typeRef)) + default: + return orderedSchema{}, fmt.Errorf("unsupported GraphQL input type kind %s", gqlType.TypeKind.String()) + } +} + +// makeNullable widens a JSON Schema to also accept null. If the schema has no +// `type` (e.g. custom scalar with open type), it is returned unchanged. 
+func makeNullable(s orderedSchema) orderedSchema { + for i, p := range s.pairs { + if p.key != "type" { + continue + } + switch v := p.value.(type) { + case string: + s.pairs[i].value = []string{v, "null"} + case []string: + if !slices.Contains(v, "null") { + s.pairs[i].value = append(v, "null") + } + } + return s + } + return s +} diff --git a/router/pkg/codemode/varschema/varschema_test.go b/router/pkg/codemode/varschema/varschema_test.go new file mode 100644 index 0000000000..adb8e12e0b --- /dev/null +++ b/router/pkg/codemode/varschema/varschema_test.go @@ -0,0 +1,79 @@ +package varschema + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" + "github.com/wundergraph/graphql-go-tools/v2/pkg/astparser" + "github.com/wundergraph/graphql-go-tools/v2/pkg/asttransform" +) + +const varsTestSchemaSDL = ` +schema { query: Query mutation: Mutation } +type Query { + user(id: ID!): User + users(filter: UserFilter, limit: Int): [User!]! +} +type Mutation { + createUser(input: UserInput!): User! +} +type User { id: ID!, name: String! } +input UserInput { name: String!, age: Int, tags: [String!] 
} +input UserFilter { name: String, status: Status } +enum Status { ACTIVE INACTIVE } +` + +func TestForOperationNoVariables(t *testing.T) { + schema := mustParseSchema(t, varsTestSchemaSDL) + + got, err := ForOperation(`query Q { user(id: "x") { id } }`, schema) + + require.NoError(t, err) + assert.Equal(t, `{"type":"object","properties":{}}`, got) +} + +func TestForOperationScalarVariables(t *testing.T) { + schema := mustParseSchema(t, varsTestSchemaSDL) + + got, err := ForOperation(`query Q($id: ID!, $limit: Int) { users(limit: $limit) { id } }`, schema) + + require.NoError(t, err) + assert.Equal(t, `{"type":"object","properties":{"id":{"type":"string"},"limit":{"type":["integer","null"]}},"required":["id"]}`, got) +} + +func TestForOperationListVariable(t *testing.T) { + schema := mustParseSchema(t, varsTestSchemaSDL) + + got, err := ForOperation(`query Q($tags: [String!]!) { users { id } }`, schema) + + require.NoError(t, err) + assert.Equal(t, `{"type":"object","properties":{"tags":{"type":"array","items":{"type":"string"}}},"required":["tags"]}`, got) +} + +func TestForOperationInputObjectVariable(t *testing.T) { + schema := mustParseSchema(t, varsTestSchemaSDL) + + got, err := ForOperation(`mutation M($input: UserInput!) { createUser(input: $input) { id } }`, schema) + + require.NoError(t, err) + assert.Equal(t, `{"type":"object","properties":{"input":{"type":"object","properties":{"name":{"type":"string"},"age":{"type":["integer","null"]},"tags":{"type":["array","null"],"items":{"type":"string"}}},"required":["name"]}},"required":["input"]}`, got) +} + +func TestForOperationEnumVariable(t *testing.T) { + schema := mustParseSchema(t, varsTestSchemaSDL) + + got, err := ForOperation(`query Q($status: Status!) 
{ users { id } }`, schema) + + require.NoError(t, err) + assert.Equal(t, `{"type":"object","properties":{"status":{"type":"string","enum":["ACTIVE","INACTIVE"]}},"required":["status"]}`, got) +} + +func mustParseSchema(t *testing.T, sdl string) *ast.Document { + t.Helper() + doc, report := astparser.ParseGraphqlDocumentString(sdl) + require.False(t, report.HasErrors(), report.Error()) + require.NoError(t, asttransform.MergeDefinitionWithBaseSchema(&doc)) + return &doc +} diff --git a/router/pkg/grpcconnector/grpcplugin/grpc_plugin.go b/router/pkg/grpcconnector/grpcplugin/grpc_plugin.go index 269d395035..1b2e728211 100644 --- a/router/pkg/grpcconnector/grpcplugin/grpc_plugin.go +++ b/router/pkg/grpcconnector/grpcplugin/grpc_plugin.go @@ -193,13 +193,13 @@ func (p *GRPCPlugin) Start(ctx context.Context) error { // Stop implements Plugin. func (p *GRPCPlugin) Stop() error { + p.mu.Lock() + defer p.mu.Unlock() + if p.disposed.Load() { return nil } - p.mu.Lock() - defer p.mu.Unlock() - var retErr error if p.client != nil { if err := p.client.Close(); err != nil { From 3bdd22915c7d5aeb21285cf7d099c264fb211f58 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Fri, 8 May 2026 10:26:20 +0200 Subject: [PATCH 09/10] fix(code-mode-demo): make subgraphs resilient to NATS unavailability MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The mood and availability subgraphs publish employee-update events to NATS as part of their mutation resolvers. Two failure modes broke the demo when NATS was not running (or not yet started): 1. cmd/all crashed at boot subgraphs.New() eagerly created and started two NATS adapters and treated any failure as fatal. Without NATS, the entire process exited with "failed to start default nats adapter". 2. 
Mutations failed at runtime mood.UpdateMood returned "no nats pubsub default provider found" when the adapter map was empty, and availability.UpdateAvailability nil-panicked on the unconditional NatsPubSubByProviderID["default"] lookup. Even when the data write succeeded (storage.Set runs first), clients saw a downstream error response. Changes: - subgraphs.go: extract startNatsAdapter helper that logs and returns nil on failure. NATS adapter startup and JetStream stream provisioning are now best-effort. - mood/availability resolvers: extract publishMoodEvent / publishAvailabilityEvent helpers that nil-check the adapter and log publish errors rather than returning them. Mutations always succeed if the local storage write succeeds. - code-mode demo start.sh and run_subgraphs_subset.sh: switch from per-subgraph cmd/ processes to a single cmd/all invocation with explicit port flags. The individual cmd/ binaries pass nil for the NATS adapter map; cmd/all wires NATS up correctly when it is available, and now degrades gracefully when it is not. - router-config.yaml: flip require_mutation_approval from true to false so the demo MCP client can run mutations end-to-end without an approval flow. - README.md: document the optional NATS prerequisite (make edfs-infra-up) and explain why the demo runs cmd/all. Trade-off: subscriptions on Employee.currentMood / isAvailable will not deliver updates while NATS is unreachable — direct queries still reflect the new state. The demo prioritizes "queries and mutations always work" over "subscriptions always work". 
Verified end-to-end with NATS stopped via docker stop cosmo-dev-nats-1: - cmd/all boots cleanly with logged warnings - updateMood and updateAvailability mutations both return successful data responses; resolvers log per-publish skip warnings - after docker start cosmo-dev-nats-1, the same mutations succeed with no resolver-level warnings (publishes go through) Co-Authored-By: Claude Opus 4.7 (1M context) --- demo/code-mode/README.md | 4 +- demo/code-mode/router-config.yaml | 2 +- demo/code-mode/run_subgraphs_subset.sh | 17 +++-- demo/code-mode/start.sh | 18 +++--- .../availability/subgraph/schema.resolvers.go | 38 ++++++----- .../mood/subgraph/schema.resolvers.go | 42 ++++++------- demo/pkg/subgraphs/subgraphs.go | 63 +++++++++---------- 7 files changed, 92 insertions(+), 92 deletions(-) diff --git a/demo/code-mode/README.md b/demo/code-mode/README.md index 9710ef72b1..7ccc8e233f 100644 --- a/demo/code-mode/README.md +++ b/demo/code-mode/README.md @@ -10,6 +10,8 @@ The set mirrors `demo/graph-no-edg.yaml`. The `employeeupdated` subgraph is inte - Node + `pnpm` (used by `wgc` to compose `demo/code-mode/graph.yaml`). - A running Yoko service reachable at `http://127.0.0.1:3400` (override with `YOKO_URL=...`). The router calls Yoko for `code_mode_search_tools`; without it, query generation will fail. +- A running NATS server reachable at `nats://localhost:4222` (override with `NATS_URL=...`). + The `mood` and `availability` mutation resolvers publish to NATS via the `default` provider; without NATS, those mutations fail at runtime with `no nats pubsub default provider found`. Bring it up with `make edfs-infra-up` from the repo root (also starts Kafka — both are part of the `edfs` Docker Compose profile). Tear down with `make edfs-infra-down`. ## Quick start @@ -59,7 +61,7 @@ curl -sS http://localhost:3002/graphql \ ## Other notes -The subset runner is `demo/code-mode/run_subgraphs_subset.sh`. 
It starts every non-EDFS subgraph used by this demo (`employees`, `family`, `hobbies`, `products`, `test1`, `availability`, `mood`, `countries`, `products_fg`) via `npx concurrently`. The full demo `demo/run_subgraphs.sh` additionally starts the EDFS-dependent `employeeupdated` subgraph and is intentionally not used here. +The subset runner is `demo/code-mode/run_subgraphs_subset.sh`. It runs `demo/cmd/all` with explicit per-subgraph port flags so every non-EDFS subgraph (`employees`, `family`, `hobbies`, `products`, `test1`, `availability`, `mood`, `countries`, `products_fg`) starts in a single process. `cmd/all` wires up the NATS pubsub adapter automatically; the per-subgraph `cmd/` binaries pass `nil` for that adapter and would fail mood/availability mutations at runtime. The full demo `demo/run_subgraphs.sh` additionally starts the EDFS-dependent `employeeupdated` subgraph and is intentionally not used here. Client configuration for Code Mode MCP clients (Claude Code, Claude Desktop, Codex CLI) lives under `demo/code-mode/mcp-configs/` — see the README there. diff --git a/demo/code-mode/router-config.yaml b/demo/code-mode/router-config.yaml index 3c676db755..7bcbb33797 100644 --- a/demo/code-mode/router-config.yaml +++ b/demo/code-mode/router-config.yaml @@ -39,7 +39,7 @@ mcp: # the MCP stdio proxy) get refused — point them at 127.0.0.1 directly # in start.sh and the proxy defaults. listen_addr: 127.0.0.1:5027 - require_mutation_approval: true + require_mutation_approval: false # Sandbox wall-clock cap. Default is 5s (plan §13), which is fine for # compute-only agent code but too short whenever the host blocks the JS # thread on an interactive elicitation. Bump to 180s so a human can review diff --git a/demo/code-mode/run_subgraphs_subset.sh b/demo/code-mode/run_subgraphs_subset.sh index d393e551b9..79447ac14f 100755 --- a/demo/code-mode/run_subgraphs_subset.sh +++ b/demo/code-mode/run_subgraphs_subset.sh @@ -6,13 +6,10 @@ cd "$(dirname "$0")/.." 
GOCACHE="${GOCACHE:-/tmp/cosmo-code-mode-go-build-cache}" mkdir -p "$GOCACHE" -npx concurrently --kill-others \ - "GOCACHE=$GOCACHE PORT=4001 go run ./cmd/employees" \ - "GOCACHE=$GOCACHE PORT=4002 go run ./cmd/family" \ - "GOCACHE=$GOCACHE PORT=4003 go run ./cmd/hobbies" \ - "GOCACHE=$GOCACHE PORT=4004 go run ./cmd/products" \ - "GOCACHE=$GOCACHE PORT=4006 go run ./cmd/test1" \ - "GOCACHE=$GOCACHE PORT=4007 go run ./cmd/availability" \ - "GOCACHE=$GOCACHE PORT=4008 go run ./cmd/mood" \ - "GOCACHE=$GOCACHE PORT=4009 go run ./cmd/countries" \ - "GOCACHE=$GOCACHE PORT=4010 go run ./cmd/products_fg" +# cmd/all bundles every subgraph into a single process with NATS pubsub +# wired up. Required for mood/availability mutations to work — the per- +# subgraph cmd/ binaries pass nil for the NATS adapter and fail at +# runtime with "no nats pubsub default provider found". +GOCACHE="$GOCACHE" go run ./cmd/all \ + -employees=4001 -family=4002 -hobbies=4003 -products=4004 \ + -test1=4006 -availability=4007 -mood=4008 -countries=4009 -products_fg=4010 diff --git a/demo/code-mode/start.sh b/demo/code-mode/start.sh index ab3e353602..acf498cfd3 100755 --- a/demo/code-mode/start.sh +++ b/demo/code-mode/start.sh @@ -126,15 +126,15 @@ mkdir -p "$GOCACHE_DIR" rm -f "$PID_FILE" trap cleanup EXIT INT TERM -start_background employees "$DEMO_DIR" env GOCACHE="$GOCACHE_DIR" PORT=4001 go run ./cmd/employees -start_background family "$DEMO_DIR" env GOCACHE="$GOCACHE_DIR" PORT=4002 go run ./cmd/family -start_background hobbies "$DEMO_DIR" env GOCACHE="$GOCACHE_DIR" PORT=4003 go run ./cmd/hobbies -start_background products "$DEMO_DIR" env GOCACHE="$GOCACHE_DIR" PORT=4004 go run ./cmd/products -start_background test1 "$DEMO_DIR" env GOCACHE="$GOCACHE_DIR" PORT=4006 go run ./cmd/test1 -start_background availability "$DEMO_DIR" env GOCACHE="$GOCACHE_DIR" PORT=4007 go run ./cmd/availability -start_background mood "$DEMO_DIR" env GOCACHE="$GOCACHE_DIR" PORT=4008 go run ./cmd/mood -start_background 
countries "$DEMO_DIR" env GOCACHE="$GOCACHE_DIR" PORT=4009 go run ./cmd/countries -start_background products_fg "$DEMO_DIR" env GOCACHE="$GOCACHE_DIR" PORT=4010 go run ./cmd/products_fg +# Use cmd/all so all subgraphs run in one process with the NATS pubsub +# adapter wired up. The individual cmd/ binaries pass nil for the +# NATS adapter map, which makes mood/availability mutations fail at +# runtime with "no nats pubsub default provider found". +# NATS_URL falls back to nats://localhost:4222 — bring NATS up via +# `make edfs-infra-up` before running this demo. +start_background subgraphs "$DEMO_DIR" env GOCACHE="$GOCACHE_DIR" go run ./cmd/all \ + -employees=4001 -family=4002 -hobbies=4003 -products=4004 \ + -test1=4006 -availability=4007 -mood=4008 -countries=4009 -products_fg=4010 wait_url employees http://localhost:4001/ wait_url family http://localhost:4002/ diff --git a/demo/pkg/subgraphs/availability/subgraph/schema.resolvers.go b/demo/pkg/subgraphs/availability/subgraph/schema.resolvers.go index 97ef578631..84b03b845b 100644 --- a/demo/pkg/subgraphs/availability/subgraph/schema.resolvers.go +++ b/demo/pkg/subgraphs/availability/subgraph/schema.resolvers.go @@ -7,6 +7,7 @@ package subgraph import ( "context" "fmt" + "log" "github.com/wundergraph/cosmo/demo/pkg/subgraphs/availability/subgraph/generated" "github.com/wundergraph/cosmo/demo/pkg/subgraphs/availability/subgraph/model" @@ -14,28 +15,31 @@ import ( "github.com/wundergraph/cosmo/router/pkg/pubsub/nats" ) -// UpdateAvailability is the resolver for the updateAvailability field. -func (r *mutationResolver) UpdateAvailability(ctx context.Context, employeeID int, isAvailable bool) (*model.Employee, error) { - storage.Set(employeeID, isAvailable) - conf := &nats.PublishAndRequestEventConfiguration{ - Subject: r.GetPubSubName(fmt.Sprintf("employeeUpdated.%d", employeeID)), +// publishAvailabilityEvent emits an Employee-updated event for subscription +// consumers. 
Failures (missing adapter, broker down) are logged but never +// fail the mutation — local storage has already been updated. +func (r *mutationResolver) publishAvailabilityEvent(ctx context.Context, providerID, subject, payload string) { + adapter := r.NatsPubSubByProviderID[providerID] + if adapter == nil { + log.Printf("availability: nats provider %q unavailable, skipping publish to %s", providerID, subject) + return } - evt := &nats.MutableEvent{Data: []byte(fmt.Sprintf(`{"id":%d,"__typename": "Employee"}`, employeeID))} - - err := r.NatsPubSubByProviderID["default"].Publish(ctx, conf, []datasource.StreamEvent{evt}) + err := adapter.Publish(ctx, &nats.PublishAndRequestEventConfiguration{ + Subject: subject, + }, []datasource.StreamEvent{&nats.MutableEvent{Data: []byte(payload)}}) if err != nil { - return nil, err + log.Printf("availability: nats publish failed via %q to %s: %v", providerID, subject, err) } +} - conf2 := &nats.PublishAndRequestEventConfiguration{ - Subject: r.GetPubSubName(fmt.Sprintf("employeeUpdatedMyNats.%d", employeeID)), - } - evt2 := &nats.MutableEvent{Data: []byte(fmt.Sprintf(`{"id":%d,"__typename": "Employee"}`, employeeID))} - err = r.NatsPubSubByProviderID["my-nats"].Publish(ctx, conf2, []datasource.StreamEvent{evt2}) +// UpdateAvailability is the resolver for the updateAvailability field. 
+func (r *mutationResolver) UpdateAvailability(ctx context.Context, employeeID int, isAvailable bool) (*model.Employee, error) { + storage.Set(employeeID, isAvailable) + payload := fmt.Sprintf(`{"id":%d,"__typename": "Employee"}`, employeeID) + + r.publishAvailabilityEvent(ctx, "default", r.GetPubSubName(fmt.Sprintf("employeeUpdated.%d", employeeID)), payload) + r.publishAvailabilityEvent(ctx, "my-nats", r.GetPubSubName(fmt.Sprintf("employeeUpdatedMyNats.%d", employeeID)), payload) - if err != nil { - return nil, err - } return &model.Employee{ID: employeeID, IsAvailable: &isAvailable}, nil } diff --git a/demo/pkg/subgraphs/mood/subgraph/schema.resolvers.go b/demo/pkg/subgraphs/mood/subgraph/schema.resolvers.go index 17ef56e9ed..05bcf10f88 100644 --- a/demo/pkg/subgraphs/mood/subgraph/schema.resolvers.go +++ b/demo/pkg/subgraphs/mood/subgraph/schema.resolvers.go @@ -7,6 +7,7 @@ package subgraph import ( "context" "fmt" + "log" "github.com/wundergraph/cosmo/demo/pkg/subgraphs/mood/subgraph/generated" "github.com/wundergraph/cosmo/demo/pkg/subgraphs/mood/subgraph/model" @@ -14,33 +15,30 @@ import ( "github.com/wundergraph/cosmo/router/pkg/pubsub/nats" ) +// publishMoodEvent emits an Employee-updated event for subscription consumers. +// Failures (missing adapter, broker down) are logged but never fail the +// mutation — local storage has already been updated. 
+func (r *mutationResolver) publishMoodEvent(ctx context.Context, providerID, subject, payload string) { + adapter := r.NatsPubSubByProviderID[providerID] + if adapter == nil { + log.Printf("mood: nats provider %q unavailable, skipping publish to %s", providerID, subject) + return + } + err := adapter.Publish(ctx, &nats.PublishAndRequestEventConfiguration{ + Subject: subject, + }, []datasource.StreamEvent{&nats.MutableEvent{Data: []byte(payload)}}) + if err != nil { + log.Printf("mood: nats publish failed via %q to %s: %v", providerID, subject, err) + } +} + // UpdateMood is the resolver for the updateMood field. func (r *mutationResolver) UpdateMood(ctx context.Context, employeeID int, mood model.Mood) (*model.Employee, error) { storage.Set(employeeID, mood) - myNatsTopic := r.GetPubSubName(fmt.Sprintf("employeeUpdated.%d", employeeID)) payload := fmt.Sprintf(`{"id":%d,"__typename": "Employee"}`, employeeID) - if r.NatsPubSubByProviderID["default"] != nil { - err := r.NatsPubSubByProviderID["default"].Publish(ctx, &nats.PublishAndRequestEventConfiguration{ - Subject: myNatsTopic, - }, []datasource.StreamEvent{&nats.MutableEvent{Data: []byte(payload)}}) - if err != nil { - return nil, err - } - } else { - return nil, fmt.Errorf("no nats pubsub default provider found") - } - defaultTopic := r.GetPubSubName(fmt.Sprintf("employeeUpdatedMyNats.%d", employeeID)) - if r.NatsPubSubByProviderID["my-nats"] != nil { - err := r.NatsPubSubByProviderID["my-nats"].Publish(ctx, &nats.PublishAndRequestEventConfiguration{ - Subject: defaultTopic, - }, []datasource.StreamEvent{&nats.MutableEvent{Data: []byte(payload)}}) - if err != nil { - return nil, err - } - } else { - return nil, fmt.Errorf("no nats pubsub my-nats provider found") - } + r.publishMoodEvent(ctx, "default", r.GetPubSubName(fmt.Sprintf("employeeUpdated.%d", employeeID)), payload) + r.publishMoodEvent(ctx, "my-nats", r.GetPubSubName(fmt.Sprintf("employeeUpdatedMyNats.%d", employeeID)), payload) return 
&model.Employee{ID: employeeID, CurrentMood: mood}, nil } diff --git a/demo/pkg/subgraphs/subgraphs.go b/demo/pkg/subgraphs/subgraphs.go index 44764f14bd..8190a9a980 100644 --- a/demo/pkg/subgraphs/subgraphs.go +++ b/demo/pkg/subgraphs/subgraphs.go @@ -204,51 +204,50 @@ func CountriesHandler(opts *SubgraphOptions) http.Handler { return subgraphHandler(countries.NewSchema(opts.NatsPubSubByProviderID)) } -func New(ctx context.Context, config *Config) (*Subgraphs, error) { - url := nats.DefaultURL - if defaultSourceNameURL := os.Getenv("NATS_URL"); defaultSourceNameURL != "" { - url = defaultSourceNameURL - } - - natsPubSubByProviderID := map[string]natsPubsub.Adapter{} - - defaultAdapter, err := natsPubsub.NewAdapter(ctx, zap.NewNop(), url, []nats.Option{}, "hostname", "test", false, datasource.ProviderOpts{ +// startNatsAdapter creates and starts a single NATS adapter. Failures are +// logged and nil is returned so callers can continue without NATS — pubsub +// publishes simply become no-ops in the resolvers. 
+func startNatsAdapter(ctx context.Context, providerID, url string) natsPubsub.Adapter { + adapter, err := natsPubsub.NewAdapter(ctx, zap.NewNop(), url, []nats.Option{}, "hostname", "test", false, datasource.ProviderOpts{ StreamMetricStore: rmetric.NewNoopStreamMetricStore(), }) if err != nil { - return nil, fmt.Errorf("failed to create default nats adapter: %w", err) + log.Printf("nats adapter %q unavailable: create failed: %v", providerID, err) + return nil } - if err := defaultAdapter.Startup(ctx); err != nil { - return nil, fmt.Errorf("failed to start default nats adapter: %w", err) + if err := adapter.Startup(ctx); err != nil { + log.Printf("nats adapter %q unavailable: startup failed: %v", providerID, err) + return nil } - natsPubSubByProviderID["default"] = defaultAdapter + return adapter +} - myNatsAdapter, err := natsPubsub.NewAdapter(ctx, zap.NewNop(), url, []nats.Option{}, "hostname", "test", false, datasource.ProviderOpts{ - StreamMetricStore: rmetric.NewNoopStreamMetricStore(), - }) - if err != nil { - return nil, fmt.Errorf("failed to create my-nats adapter: %w", err) - } - if err := myNatsAdapter.Startup(ctx); err != nil { - return nil, fmt.Errorf("failed to start my-nats adapter: %w", err) +func New(ctx context.Context, config *Config) (*Subgraphs, error) { + url := nats.DefaultURL + if defaultSourceNameURL := os.Getenv("NATS_URL"); defaultSourceNameURL != "" { + url = defaultSourceNameURL } - natsPubSubByProviderID["my-nats"] = myNatsAdapter - defaultConnection, err := nats.Connect(url) - if err != nil { - log.Printf("failed to connect to nats source \"nats\": %v", err) + natsPubSubByProviderID := map[string]natsPubsub.Adapter{} + if a := startNatsAdapter(ctx, "default", url); a != nil { + natsPubSubByProviderID["default"] = a } - defaultJetStream, err := jetstream.New(defaultConnection) - if err != nil { - return nil, err + if a := startNatsAdapter(ctx, "my-nats", url); a != nil { + natsPubSubByProviderID["my-nats"] = a } - _, err = 
defaultJetStream.CreateOrUpdateStream(ctx, jetstream.StreamConfig{ + // JetStream stream provisioning is also best-effort — when NATS is not + // reachable we just skip it. The subgraphs' subscription functionality + // will be unavailable, but plain queries and mutations keep working. + if defaultConnection, err := nats.Connect(url); err != nil { + log.Printf("nats: skipping jetstream stream provisioning (connect failed): %v", err) + } else if defaultJetStream, err := jetstream.New(defaultConnection); err != nil { + log.Printf("nats: skipping jetstream stream provisioning (jetstream init failed): %v", err) + } else if _, err := defaultJetStream.CreateOrUpdateStream(ctx, jetstream.StreamConfig{ Name: "streamName", Subjects: []string{"employeeUpdated.>"}, - }) - if err != nil { - return nil, err + }); err != nil { + log.Printf("nats: skipping jetstream stream provisioning (CreateOrUpdateStream failed): %v", err) } var servers []*http.Server From df3562bcfbc599bd899d2ec23ab1a57ad5bd0878 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Fri, 8 May 2026 10:56:12 +0200 Subject: [PATCH 10/10] docs(code-mode): split unrelated mutations into separate search prompts Bundling unrelated writes produced tangled operations with mixed argument shapes; the search-tool description now requires one prompt per logical write and reserves combination for tightly-correlated cascades. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- .../server/descriptions/search_tool.md | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/router/internal/codemode/server/descriptions/search_tool.md b/router/internal/codemode/server/descriptions/search_tool.md index 64581a38e1..5183478bd6 100644 --- a/router/internal/codemode/server/descriptions/search_tool.md +++ b/router/internal/codemode/server/descriptions/search_tool.md @@ -18,9 +18,24 @@ Always state: - Any required filters/arguments but never specific values ("employee by id - not "employee 123", "employee filtered by department name" - not "employee in department 'Engineering'"). - Concrete entity and relationship names from the domain when you know them; otherwise describe the relationship explicitly ("the team an employee belongs to"). -When to use multiple prompts (rare): genuinely unrelated operations on disjoint domains, different argument shapes that can't share a parent, or queries vs mutations. +When to use multiple prompts (rare for reads): genuinely unrelated operations on disjoint domains, different argument shapes that can't share a parent, or queries vs mutations. Never slice one joinable shape into fragments. -When in doubt, combine. +When in doubt for reads, combine. + +MUTATIONS ARE DIFFERENT — DEFAULT TO ONE PROMPT PER LOGICAL WRITE. +Mutations have side effects and are imperative, not joinable. +Bundling unrelated writes into one prompt produces tangled operations with mixed argument shapes and unclear failure semantics. +Issue a SEPARATE prompt for each mutation that is not tightly correlated with the others. +Tightly correlated means: same target entity (e.g. update name + update email on the same user), a parent/child cascade that must be authored together (create order + add line items to that order), or writes that share the same input shape and variables. 
+Unrelated mutations on different entities, different argument shapes, or independently triggered by the user MUST be issued as separate prompts — even if you are calling them in the same code_mode_run_js batch. +A read prompt and a write prompt never share a single search prompt; describe reads and writes separately. + +Mutation example (correct, two prompts): +- "mutation: update employee by id, set forename and surname; return the updated employee with id, forename, surname" +- "mutation: archive a project by id; return the archived project with id, status, archivedAt" + +Mutation example (correct, one prompt — tightly correlated cascade): +- "mutation: create a project with title and ownerId, then add an array of tasks (each with title, dueDate) to the new project; return the project with id, title and its tasks with id, title, dueDate" Do NOT issue prompts for derived/computed values: averages, medians, counts, filters, exclusions ("without X"), sorting, top-N. Fetch the raw rows once and compute in code_mode_run_js.