diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 98eb145..92247cd 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,92 +2,15 @@ name: CI on: pull_request: - push: - branches: - - main permissions: contents: read -defaults: - run: - shell: bash +concurrency: + group: ci-${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true jobs: - test: - name: Build and test - runs-on: ubuntu-24.04 - env: - IMPORT_NAME: src_py_lib - PYTHON_VERSION: "3.11" - UV_VERSION: "0.11.7" - - steps: - - name: Check out code - uses: actions/checkout@v6 - with: - persist-credentials: false - - - name: Set up Python - uses: actions/setup-python@v6 - with: - python-version: ${{ env.PYTHON_VERSION }} - cache: pip - - - name: Install uv - run: | - python -m pip install --upgrade pip - python -m pip install "uv==${UV_VERSION}" - - - name: Validate lockfile - run: uv lock --check - - - name: Lint Markdown - run: npx --yes markdownlint-cli2 - - - name: Lint Python - run: uv run --frozen ruff check . - - - name: Check Python formatting - run: uv run --frozen ruff format --check . - - - name: Type check - run: uv run --frozen pyright - - - name: Run tests - run: uv run --frozen python -m unittest discover -s tests - - - name: Smoke test source checkout import - run: | - uv run --frozen python - <<'PY' - import os - - import src_py_lib - - if src_py_lib.__name__ != os.environ["IMPORT_NAME"]: - raise SystemExit(f"unexpected import name: {src_py_lib.__name__}") - PY - - - name: Build wheel - run: uv build --wheel --out-dir dist --no-create-gitignore - - - name: Smoke test installed wheel - run: | - python -m venv build/ci-venv - . build/ci-venv/bin/activate - python -m pip install --upgrade pip - python -m pip install dist/*.whl - python - <<'PY' - import os - - import src_py_lib - - if src_py_lib.__name__ != os.environ["IMPORT_NAME"]: - raise SystemExit(f"unexpected import name: {src_py_lib.__name__}") - PY - - - name: Upload wheel artifact - uses: actions/upload-artifact@v7 - with: - name: src-py-lib-wheel - path: dist/*.whl + validate: + name: Validate + uses: ./.github/workflows/validate.yml diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 0d8c9c6..8d2fc27 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -23,8 +23,15 @@ defaults: shell: bash jobs: + validate: + name: Validate + uses: ./.github/workflows/validate.yml + with: + ref: ${{ github.event.inputs.tag || github.ref }} + wheel: name: Build wheel + needs: validate runs-on: ubuntu-24.04 env: IMPORT_NAME: src_py_lib @@ -45,6 +52,14 @@ jobs: python-version: ${{ env.PYTHON_VERSION }} cache: pip + - name: Cache uv + uses: actions/cache@v4 + with: + path: ~/.cache/uv + key: uv-${{ runner.os }}-py${{ env.PYTHON_VERSION }}-${{ hashFiles('uv.lock') }} + restore-keys: | + uv-${{ runner.os }}-py${{ env.PYTHON_VERSION }}- + - name: Install build tools run: | python -m pip install --upgrade pip @@ -62,6 +77,14 @@ jobs: echo "::error title=Missing tag::Tag '${release_tag}' was not fetched. Create and push it before running this workflow." exit 1 fi + tag_revision="$(git rev-list -n 1 "${release_tag}")" + git fetch --no-tags origin main + main_revision="$(git rev-parse origin/main)" + if ! git merge-base --is-ancestor "${tag_revision}" "${main_revision}"; then + echo "::error title=Tag is not on main::Tag '${release_tag}' points at ${tag_revision}, which is not reachable from origin/main." + echo "::error::Merge the release PR first, then tag the main commit." + exit 1 + fi project_version=$(uv run --frozen python - <<'PY' import tomllib @@ -77,22 +100,6 @@ jobs: echo "tag=${release_tag}" >> "${GITHUB_OUTPUT}" - - name: Validate package - run: | - uv lock --check - uv run --frozen ruff check . - uv run --frozen ruff format --check . - uv run --frozen pyright - uv run --frozen python -m unittest discover -s tests - uv run --frozen python - <<'PY' - import os - - import src_py_lib - - if src_py_lib.__name__ != os.environ["IMPORT_NAME"]: - raise SystemExit(f"unexpected import name: {src_py_lib.__name__}") - PY - - name: Build distributions id: build run: | @@ -115,17 +122,24 @@ jobs: wheel_path="${project_wheels[0]}" wheel_name="$(basename "${wheel_path}")" source_distribution_path="${source_distributions[0]}" - checksum_path="${wheel_path}.sha256" + source_distribution_name="$(basename "${source_distribution_path}")" + wheel_checksum_path="${wheel_path}.sha256" + source_distribution_checksum_path="${source_distribution_path}.sha256" ( cd "$(dirname "${wheel_path}")" - shasum -a 256 "${wheel_name}" > "$(basename "${checksum_path}")" + shasum -a 256 "${wheel_name}" > "$(basename "${wheel_checksum_path}")" + shasum -a 256 "${source_distribution_name}" > "$(basename "${source_distribution_checksum_path}")" ) - echo "wheel_path=${wheel_path}" >> "${GITHUB_OUTPUT}" - echo "wheel_name=${wheel_name}" >> "${GITHUB_OUTPUT}" - echo "source_distribution_path=${source_distribution_path}" >> "${GITHUB_OUTPUT}" - echo "checksum_path=${checksum_path}" >> "${GITHUB_OUTPUT}" + { + echo "wheel_path=${wheel_path}" + echo "wheel_name=${wheel_name}" + echo "source_distribution_path=${source_distribution_path}" + echo "source_distribution_name=${source_distribution_name}" + echo "wheel_checksum_path=${wheel_checksum_path}" + echo "source_distribution_checksum_path=${source_distribution_checksum_path}" + } >> "${GITHUB_OUTPUT}" - name: Smoke test installed wheel run: | @@ -147,6 +161,7 @@ jobs: run: | release_tag="${{ steps.release.outputs.tag }}" wheel_name="${{ steps.build.outputs.wheel_name }}" + source_distribution_name="${{ steps.build.outputs.source_distribution_name }}" notes_path="build/release/release-notes.md" cat > "${notes_path}" <> "${GITHUB_OUTPUT}" @@ -179,7 +200,9 @@ jobs: name: src-py-lib-release path: | ${{ steps.build.outputs.wheel_path }} - ${{ steps.build.outputs.checksum_path }} + ${{ steps.build.outputs.source_distribution_path }} + ${{ steps.build.outputs.wheel_checksum_path }} + ${{ steps.build.outputs.source_distribution_checksum_path }} ${{ steps.notes.outputs.path }} - name: Upload PyPI artifact @@ -196,16 +219,23 @@ jobs: run: | release_tag="${{ steps.release.outputs.tag }}" wheel_path="${{ steps.build.outputs.wheel_path }}" - checksum_path="${{ steps.build.outputs.checksum_path }}" + source_distribution_path="${{ steps.build.outputs.source_distribution_path }}" + wheel_checksum_path="${{ steps.build.outputs.wheel_checksum_path }}" + source_distribution_checksum_path="${{ steps.build.outputs.source_distribution_checksum_path }}" notes_path="${{ steps.notes.outputs.path }}" + release_assets=( + "${wheel_path}" + "${source_distribution_path}" + "${wheel_checksum_path}" + "${source_distribution_checksum_path}" + ) if gh release view "${release_tag}" >/dev/null 2>&1; then gh release edit "${release_tag}" --title "${release_tag}" --notes-file "${notes_path}" - gh release upload "${release_tag}" "${wheel_path}" "${checksum_path}" --clobber + gh release upload "${release_tag}" "${release_assets[@]}" --clobber else gh release create "${release_tag}" \ - "${wheel_path}" \ - "${checksum_path}" \ + "${release_assets[@]}" \ --title "${release_tag}" \ --notes-file "${notes_path}" \ --verify-tag diff --git a/.github/workflows/validate.yml b/.github/workflows/validate.yml new file mode 100644 index 0000000..0920e0f --- /dev/null +++ b/.github/workflows/validate.yml @@ -0,0 +1,124 @@ +name: Validate + +on: + workflow_call: + inputs: + ref: + description: "Git ref to validate. Defaults to the caller's ref." + required: false + type: string + +permissions: + contents: read + +defaults: + run: + shell: bash + +jobs: + package: + name: Validate package + runs-on: ubuntu-24.04 + env: + ACTIONLINT_VERSION: "1.7.12" + IMPORT_NAME: src_py_lib + MARKDOWNLINT_CLI2_VERSION: "0.22.1" + PYTHON_VERSION: "3.11" + UV_VERSION: "0.11.7" + + steps: + - name: Check out code + uses: actions/checkout@v6 + with: + persist-credentials: false + ref: ${{ inputs.ref || github.ref }} + + - name: Cache actionlint + id: cache-actionlint + uses: actions/cache@v4 + with: + path: ~/.local/bin/actionlint + key: actionlint-${{ runner.os }}-${{ runner.arch }}-${{ env.ACTIONLINT_VERSION }} + + - name: Install actionlint + if: steps.cache-actionlint.outputs.cache-hit != 'true' + run: | + mkdir -p "${HOME}/.local/bin" + go install "github.com/rhysd/actionlint/cmd/actionlint@v${ACTIONLINT_VERSION}" + install -m 0755 "${HOME}/go/bin/actionlint" "${HOME}/.local/bin/actionlint" + + - name: Lint GitHub Actions + run: | + "${HOME}/.local/bin/actionlint" + + - name: Cache npm + uses: actions/cache@v4 + with: + path: ~/.npm + key: npm-${{ runner.os }}-markdownlint-cli2-${{ env.MARKDOWNLINT_CLI2_VERSION }} + + - name: Lint Markdown + run: npx --yes "markdownlint-cli2@${MARKDOWNLINT_CLI2_VERSION}" + + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: ${{ env.PYTHON_VERSION }} + cache: pip + + - name: Cache uv + uses: actions/cache@v4 + with: + path: ~/.cache/uv + key: uv-${{ runner.os }}-py${{ env.PYTHON_VERSION }}-${{ hashFiles('uv.lock') }} + restore-keys: | + uv-${{ runner.os }}-py${{ env.PYTHON_VERSION }}- + + - name: Install uv + run: | + python -m pip install --upgrade pip + python -m pip install "uv==${UV_VERSION}" + + - name: Validate lockfile + run: uv lock --check + + - name: Lint Python + run: uv run --frozen ruff check . + + - name: Check Python formatting + run: uv run --frozen ruff format --check . + + - name: Type check + run: uv run --frozen pyright + + - name: Run tests + run: uv run --frozen python -m unittest discover -s tests + + - name: Smoke test source checkout import + run: | + uv run --frozen python - <<'PY' + import os + + import src_py_lib + + if src_py_lib.__name__ != os.environ["IMPORT_NAME"]: + raise SystemExit(f"unexpected import name: {src_py_lib.__name__}") + PY + + - name: Build wheel + run: uv build --wheel --out-dir dist --no-create-gitignore + + - name: Smoke test installed wheel + run: | + python -m venv build/ci-venv + . build/ci-venv/bin/activate + python -m pip install --upgrade pip + python -m pip install dist/*.whl + python - <<'PY' + import os + + import src_py_lib + + if src_py_lib.__name__ != os.environ["IMPORT_NAME"]: + raise SystemExit(f"unexpected import name: {src_py_lib.__name__}") + PY diff --git a/AGENTS.md b/AGENTS.md index 01f9076..550004f 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -14,7 +14,8 @@ ## Standard commands ```sh -npx --yes markdownlint-cli2 +actionlint +npx --yes markdownlint-cli2@0.22.1 uv sync uv run ruff format . uv run ruff check . @@ -56,6 +57,9 @@ uv run python -m unittest discover -s tests verifies that it matches `project.version` before building GitHub release assets and publishing to PyPI. - Prepare releases on a branch from current `main`. Set `VERSION`, then run: +- As part of every release bump, find old release-version literals in + `AGENTS.md`, `README.md`, and release snippets, and replace them with the + new version where they are meant to stay current. ```sh set -euo pipefail @@ -96,7 +100,8 @@ uv lock set -euo pipefail uv lock --check -npx --yes markdownlint-cli2 +actionlint +npx --yes markdownlint-cli2@0.22.1 uv run ruff check . uv run ruff format --check . uv run pyright diff --git a/README.md b/README.md index 08816de..1c2e446 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,13 @@ This repo was created for Sourcegraph Implementation Engineering deployments, and is not intended, designed, built, or supported for use in any other scenario. Feel free to open issues or PRs, but responses are best effort. +## Semantic Versioning + +- Release versions are `major.minor.patch` +- Because this project is still major version 0: + - Minor version updates are breaking changes + - Patch version updates are not breaking changes + ## Install From PyPI: diff --git a/pyproject.toml b/pyproject.toml index 664ad0a..3366a4d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,7 @@ Issues = "https://github.com/sourcegraph/src-py-lib/issues" packages = ["src/src_py_lib"] [tool.pyright] -include = ["src/src_py_lib"] +include = ["src/src_py_lib", "tests"] typeCheckingMode = "strict" extraPaths = ["src"] pythonVersion = "3.11" diff --git a/src/src_py_lib/__init__.py b/src/src_py_lib/__init__.py index 66e2710..85a6ace 100644 --- a/src/src_py_lib/__init__.py +++ b/src/src_py_lib/__init__.py @@ -37,6 +37,14 @@ from src_py_lib.clients.sourcegraph import ( SourcegraphClient, SourcegraphClientConfig, + SourcegraphJaegerTraceError, + SourcegraphJaegerTraceSummary, + SourcegraphTrace, + decode_external_service_id, + decode_repository_id, + decode_sourcegraph_node_id, + encode_repository_id, + encode_sourcegraph_node_id, normalize_sourcegraph_endpoint, sourcegraph_client_from_config, ) @@ -49,7 +57,7 @@ from src_py_lib.utils.config import ( config_parse_args as parse_args, ) -from src_py_lib.utils.http import HTTPClient, HTTPClientError +from src_py_lib.utils.http import HTTPClient, HTTPClientError, HTTPResponse from src_py_lib.utils.json_cache import load_json_cache, load_json_subset, save_json_cache from src_py_lib.utils.json_types import ( JSONDict, @@ -63,8 +71,10 @@ from src_py_lib.utils.logging import ( LoggingConfig, LoggingSettings, + TraceContext, configure_logging, critical, + current_trace_context, debug, error, event, @@ -73,10 +83,15 @@ log_context, logging_context, logging_settings_from_config, + new_trace_context, resolve_log_level_name, + sampled_traceparent, stage, startup_event, submit_with_log_context, + trace_context, + trace_context_from_traceparent, + traceparent_header, warning, ) from src_py_lib.utils.tsv import write_tsv @@ -116,6 +131,7 @@ def _script_name() -> str: "GoogleSheetsError", "HTTPClient", "HTTPClientError", + "HTTPResponse", "JSONDict", "LinearClient", "LinearClientConfig", @@ -128,12 +144,22 @@ def _script_name() -> str: "SlackPacer", "SourcegraphClient", "SourcegraphClientConfig", + "SourcegraphJaegerTraceError", + "SourcegraphJaegerTraceSummary", + "SourcegraphTrace", + "TraceContext", "aliased_batched_query", "config_field", "config_snapshot", "configure_logging", "critical", + "current_trace_context", "debug", + "decode_external_service_id", + "decode_repository_id", + "decode_sourcegraph_node_id", + "encode_repository_id", + "encode_sourcegraph_node_id", "error", "event", "gh_cli_token", @@ -153,18 +179,23 @@ def _script_name() -> str: "logging_settings_from_config", "log", "log_context", + "new_trace_context", "normalize_sourcegraph_endpoint", "parse_args", "pr_ref_from_url", "quota_project_from_adc", "resolve_log_level_name", "save_json_cache", + "sampled_traceparent", "slack_client_from_config", "sourcegraph_client_from_config", "stage", "startup_event", "stream_connection_nodes", "submit_with_log_context", + "trace_context", + "trace_context_from_traceparent", + "traceparent_header", "warning", "write_tsv", ] diff --git a/src/src_py_lib/clients/graphql.py b/src/src_py_lib/clients/graphql.py index fe4e226..4dbf143 100644 --- a/src/src_py_lib/clients/graphql.py +++ b/src/src_py_lib/clients/graphql.py @@ -9,11 +9,13 @@ from pathlib import Path from typing import cast -from src_py_lib.utils.http import HTTPClient, HTTPClientError +from src_py_lib.utils.http import HTTPClient, HTTPClientError, HTTPResponse from src_py_lib.utils.json_types import JSONDict, JSONValue, json_dict, json_list, json_str from src_py_lib.utils.logging import event _OPERATION_NAME_RE = re.compile(r"\b(?:query|mutation|subscription)\s+(\w+)") +HeaderProvider = Mapping[str, str] | Callable[[], Mapping[str, str]] +GraphQLResponseHook = Callable[[HTTPResponse, Mapping[str, str]], None] GRAPHQL_INTROSPECTION_QUERY = """ query IntrospectionQuery { @@ -130,10 +132,11 @@ class GraphQLClient: """POST JSON GraphQL operations and return the `data` object.""" url: str - headers: dict[str, str] + headers: HeaderProvider label: str http: HTTPClient = field(default_factory=HTTPClient) tolerate_partial_errors: bool = False + response_hook: GraphQLResponseHook | None = None def execute( self, @@ -251,7 +254,16 @@ def _execute_once( query_bytes=len(query.encode("utf-8")), ) as fields: try: - payload = self.http.json("POST", self.url, headers=self.headers, json_body=body) + request_headers = self._headers() + if self.response_hook is None: + payload = self.http.json( + "POST", self.url, headers=request_headers, json_body=body + ) + else: + payload, response = self.http.json_response( + "POST", self.url, headers=request_headers, json_body=body + ) + self.response_hook(response, request_headers) except HTTPClientError as exception: raise GraphQLError( f"{self.label} GraphQL request failed: {exception}", @@ -269,6 +281,11 @@ def _execute_once( ) return data + def _headers(self) -> dict[str, str]: + if callable(self.headers): + return dict(self.headers()) + return dict(self.headers) + def operation_name(query: str) -> str: """Extract the operation name from a GraphQL document.""" diff --git a/src/src_py_lib/clients/sourcegraph.py b/src/src_py_lib/clients/sourcegraph.py index ec9b158..b980fb1 100644 --- a/src/src_py_lib/clients/sourcegraph.py +++ b/src/src_py_lib/clients/sourcegraph.py @@ -2,16 +2,37 @@ from __future__ import annotations -from collections.abc import Iterator, Mapping, Sequence +import base64 +import collections +import json +import queue +import time +from collections.abc import Iterable, Iterator, Mapping, Sequence from dataclasses import dataclass, field +from typing import Final, cast from urllib.parse import urlsplit from src_py_lib.clients.graphql import GraphQLClient, stream_connection_nodes from src_py_lib.utils.config import Config, config_field -from src_py_lib.utils.http import HTTPClient -from src_py_lib.utils.json_types import JSONDict, JSONValue, json_dict +from src_py_lib.utils.http import HTTPClient, HTTPClientError, HTTPResponse +from src_py_lib.utils.json_types import JSONDict, JSONValue, json_dict, json_list +from src_py_lib.utils.logging import ( + current_trace_context, + new_trace_context, + trace_context_from_traceparent, + traceparent_header, +) DEFAULT_SOURCEGRAPH_ENDPOINT = "https://sourcegraph.com" +SOURCEGRAPH_EXTERNAL_SERVICE_NODE_TYPE: Final[str] = "ExternalService" +SOURCEGRAPH_REPOSITORY_NODE_TYPE: Final[str] = "Repository" +REQUEST_TRACE_HEADER: Final[str] = "X-Sourcegraph-Request-Trace" +TRACEPARENT_HEADER: Final[str] = "traceparent" +TRACE_ID_RESPONSE_HEADER: Final[str] = "x-trace" +TRACE_SPAN_RESPONSE_HEADER: Final[str] = "x-trace-span" +TRACE_URL_RESPONSE_HEADER: Final[str] = "x-trace-url" +JAEGER_TRACE_RETRY_DELAYS_SECONDS: Final[tuple[float, ...]] = (0.0, 2.0, 5.0) +RETRYABLE_JAEGER_TRACE_STATUS_CODES: Final[frozenset[int]] = frozenset({404, 502, 503, 504}) SOURCEGRAPH_VALIDATE_QUERY = """ query SourcegraphClientValidate { currentUser { @@ -21,6 +42,58 @@ """ +class SourcegraphJaegerTraceError(RuntimeError): + """Raised when a Sourcegraph Jaeger/debug trace cannot be fetched.""" + + +@dataclass(frozen=True) +class SourcegraphTrace: + """Trace metadata Sourcegraph returned for one traced request.""" + + trace_id: str + span_id: str | None = None + trace_url: str | None = None + parent_trace_id: str | None = None + parent_span_id: str | None = None + + def to_json(self) -> JSONDict: + payload: JSONDict = {"trace_id": self.trace_id} + if self.span_id is not None: + payload["span_id"] = self.span_id + if self.trace_url is not None: + payload["trace_url"] = self.trace_url + if self.parent_trace_id is not None: + payload["parent_trace_id"] = self.parent_trace_id + if self.parent_span_id is not None: + payload["parent_span_id"] = self.parent_span_id + return payload + + +@dataclass(frozen=True) +class SourcegraphJaegerTraceSummary: + """Compact summary of one Sourcegraph Jaeger/debug trace.""" + + trace: SourcegraphTrace + jaeger_found: bool + span_count: int = 0 + hot_operations: tuple[JSONDict, ...] = () + graphql_operations: tuple[JSONDict, ...] = () + errored_spans: tuple[JSONDict, ...] = () + error: str = "" + + def to_json(self) -> JSONDict: + payload = self.trace.to_json() + payload["jaeger_found"] = self.jaeger_found + if not self.jaeger_found: + payload["error"] = self.error + return payload + payload["span_count"] = self.span_count + payload["hot_operations"] = [dict(operation) for operation in self.hot_operations] + payload["graphql_operations"] = [dict(operation) for operation in self.graphql_operations] + payload["errored_spans"] = [dict(span) for span in self.errored_spans] + return payload + + def normalize_sourcegraph_endpoint(endpoint: str, *, require_https: bool = False) -> str: """Return a stable Sourcegraph base URL, or raise ValueError.""" normalized_endpoint = endpoint.strip().rstrip("/") @@ -41,6 +114,44 @@ def normalize_sourcegraph_endpoint(endpoint: str, *, require_https: bool = False return normalized_endpoint +def encode_sourcegraph_node_id(node_type: str, database_id: int) -> str: + """Return a Sourcegraph opaque GraphQL Node ID for `node_type:database_id`.""" + raw = f"{node_type}:{database_id}".encode() + return base64.b64encode(raw).decode() + + +def decode_sourcegraph_node_id(node_type: str, graphql_id: str) -> int: + """Return the database ID from a Sourcegraph opaque GraphQL Node ID.""" + try: + raw = base64.b64decode(graphql_id, validate=True).decode() + except (ValueError, UnicodeDecodeError) as exception: + raise ValueError(f"not a valid base64 GraphQL Node ID: {graphql_id!r}") from exception + decoded_node_type, separator, database_id = raw.partition(":") + if not separator or decoded_node_type != node_type: + raise ValueError(f"not a {node_type} Node ID: {graphql_id!r} (decoded: {raw!r})") + try: + return int(database_id) + except ValueError as exception: + raise ValueError( + f"{node_type} Node ID has non-integer suffix: {graphql_id!r} (decoded: {raw!r})" + ) from exception + + +def decode_external_service_id(graphql_id: str) -> int: + """Return the database ID from an opaque ExternalService GraphQL Node ID.""" + return decode_sourcegraph_node_id(SOURCEGRAPH_EXTERNAL_SERVICE_NODE_TYPE, graphql_id) + + +def encode_repository_id(database_id: int) -> str: + """Return an opaque Repository GraphQL Node ID from a database ID.""" + return encode_sourcegraph_node_id(SOURCEGRAPH_REPOSITORY_NODE_TYPE, database_id) + + +def decode_repository_id(graphql_id: str) -> int: + """Return the database ID from an opaque Repository GraphQL Node ID.""" + return decode_sourcegraph_node_id(SOURCEGRAPH_REPOSITORY_NODE_TYPE, graphql_id) + + class SourcegraphClientConfig(Config): """Config fields needed to build a Sourcegraph API client.""" @@ -68,11 +179,20 @@ class SourcegraphClient: `endpoint` should be the instance base URL, for example `https://sourcegraph.example.com`. + + Set `trace=True` to ask Sourcegraph to retain traces for each GraphQL + request. Traced requests are available through `drain_traces()` and can be + fetched from the instance's Jaeger/debug endpoint with + `stream_jaeger_trace_summaries()`. """ endpoint: str token: str http: HTTPClient = field(default_factory=HTTPClient) + trace: bool = False + _traces: queue.Queue[SourcegraphTrace] = field( + default_factory=lambda: queue.Queue[SourcegraphTrace](), init=False, repr=False + ) def __post_init__(self) -> None: self.endpoint = normalize_sourcegraph_endpoint(self.endpoint) @@ -110,18 +230,243 @@ def validate(self) -> JSONDict: ) return current_user + def drain_traces(self) -> list[SourcegraphTrace]: + """Return traced request metadata recorded since the last drain.""" + traces: list[SourcegraphTrace] = [] + while True: + try: + traces.append(self._traces.get_nowait()) + except queue.Empty: + return traces + + def stream_jaeger_trace_summaries( + self, + traces: Iterable[SourcegraphTrace] | None = None, + *, + retry_delays_seconds: Sequence[float] = JAEGER_TRACE_RETRY_DELAYS_SECONDS, + ) -> Iterator[SourcegraphJaegerTraceSummary]: + """Yield compact Jaeger/debug summaries for traced Sourcegraph requests.""" + for trace in self.drain_traces() if traces is None else traces: + yield self.fetch_jaeger_trace_summary( + trace, + retry_delays_seconds=retry_delays_seconds, + ) + + def fetch_jaeger_trace_summary( + self, + trace: SourcegraphTrace | str, + *, + retry_delays_seconds: Sequence[float] = JAEGER_TRACE_RETRY_DELAYS_SECONDS, + ) -> SourcegraphJaegerTraceSummary: + """Fetch one Jaeger/debug trace and return a compact summary.""" + trace_metadata = trace if isinstance(trace, SourcegraphTrace) else SourcegraphTrace(trace) + try: + jaeger_trace = self.fetch_jaeger_trace( + trace_metadata.trace_id, + retry_delays_seconds=retry_delays_seconds, + ) + except SourcegraphJaegerTraceError as error: + return SourcegraphJaegerTraceSummary( + trace=trace_metadata, + jaeger_found=False, + error=str(error), + ) + return summarize_jaeger_trace(trace_metadata, jaeger_trace) + + def fetch_jaeger_trace( + self, + trace_id: str, + *, + retry_delays_seconds: Sequence[float] = JAEGER_TRACE_RETRY_DELAYS_SECONDS, + ) -> JSONDict: + """Fetch a raw Jaeger/debug trace from the Sourcegraph instance.""" + url = f"{self.endpoint}/-/debug/jaeger/api/traces/{trace_id}" + last_error = "trace not found" + for delay_seconds in retry_delays_seconds: + if delay_seconds > 0: + time.sleep(delay_seconds) + try: + payload = self.http.json("GET", url, headers=self._authorization_headers()) + except HTTPClientError as error: + last_error = sourcegraph_trace_fetch_error(error) + if ( + error.status_code is None + or error.status_code in RETRYABLE_JAEGER_TRACE_STATUS_CODES + ): + continue + raise SourcegraphJaegerTraceError(last_error) from error + for trace_value in json_list(payload.get("data")): + jaeger_trace = json_dict(trace_value) + if jaeger_trace: + return jaeger_trace + errors = payload.get("errors") + last_error = json.dumps(errors) if errors else "trace not found" + raise SourcegraphJaegerTraceError(last_error) + def _client(self) -> GraphQLClient: return GraphQLClient( url=f"{self.endpoint}/.api/graphql", - headers={"Authorization": f"token {self.token}"}, + headers=self._graphql_headers, label="Sourcegraph", http=self.http, + response_hook=self._record_trace_response if self.trace else None, ) + def _authorization_headers(self) -> dict[str, str]: + return {"Authorization": f"token {self.token}"} + + def _graphql_headers(self) -> dict[str, str]: + headers = self._authorization_headers() + if self.trace: + headers[REQUEST_TRACE_HEADER] = "true" + headers[TRACEPARENT_HEADER] = traceparent_header( + current_trace_context() or new_trace_context() + ) + return headers + + def _record_trace_response( + self, response: HTTPResponse, request_headers: Mapping[str, str] + ) -> None: + trace = sourcegraph_trace_from_headers(response.headers, request_headers) + if trace is not None: + self._traces.put(trace) + -def sourcegraph_client_from_config(config: SourcegraphClientConfig) -> SourcegraphClient: +def sourcegraph_client_from_config( + config: SourcegraphClientConfig, + *, + http: HTTPClient | None = None, + trace: bool = False, +) -> SourcegraphClient: """Return a Sourcegraph API client from shared Sourcegraph Config fields.""" return SourcegraphClient( endpoint=config.src_endpoint, token=config.src_access_token, + http=http or HTTPClient(), + trace=trace, + ) + + +def sampled_traceparent() -> str: + """Compatibility wrapper for sampled W3C traceparent generation.""" + return traceparent_header(sampled=True) + + +def sourcegraph_trace_from_headers( + response_headers: Mapping[str, str], request_headers: Mapping[str, str] +) -> SourcegraphTrace | None: + """Return Sourcegraph trace metadata from request/response headers.""" + trace_id = header_value(response_headers, TRACE_ID_RESPONSE_HEADER) + if trace_id is None or not is_hex_identifier(trace_id, 32): + return None + span_id = header_value(response_headers, TRACE_SPAN_RESPONSE_HEADER) + trace_url = header_value(response_headers, TRACE_URL_RESPONSE_HEADER) + parent = trace_context_from_traceparent(header_value(request_headers, TRACEPARENT_HEADER)) + return SourcegraphTrace( + trace_id=trace_id.lower(), + span_id=span_id.lower() if span_id and is_hex_identifier(span_id, 16) else span_id, + trace_url=trace_url, + parent_trace_id=parent.trace_id if parent is not None else None, + parent_span_id=parent.span_id if parent is not None else None, + ) + + +def is_hex_identifier(value: str, length: int) -> bool: + """Return whether `value` is a non-zero hex identifier of `length` characters.""" + lowered = value.lower() + return ( + len(lowered) == length + and any(character != "0" for character in lowered) + and all(character in "0123456789abcdef" for character in lowered) ) + + +def header_value(headers: Mapping[str, str], name: str) -> str | None: + """Return one header value by case-insensitive name.""" + lower_name = name.lower() + for header_name, value in headers.items(): + if header_name.lower() == lower_name: + return value + return None + + +def sourcegraph_trace_fetch_error(error: HTTPClientError) -> str: + """Return a concise, user-safe Jaeger trace fetch error.""" + if error.status_code is None: + return str(error) + return f"HTTP {error.status_code}" + (f": {error.body[:200]}" if error.body else "") + + +def summarize_jaeger_trace( + trace_metadata: SourcegraphTrace, jaeger_trace: JSONDict +) -> SourcegraphJaegerTraceSummary: + """Return a compact summary of one raw Jaeger trace payload.""" + spans = json_list(jaeger_trace.get("spans")) + durations_by_operation: dict[str, list[float]] = collections.defaultdict(list) + graphql_operations: collections.Counter[str] = collections.Counter() + errored_spans: list[JSONDict] = [] + + for span_value in spans: + span = json_dict(span_value) + if not span: + continue + operation = str(span.get("operationName") or "") + duration_ms = float_value(span.get("duration")) / 1000.0 + durations_by_operation[operation].append(duration_ms) + tags = jaeger_span_tags(span) + operation_name = tags.get("graphql.operationName") + if isinstance(operation_name, str): + graphql_operations[operation_name] += 1 + if tags.get("error") in {True, "true", "True"}: + errored_spans.append( + { + "operation": operation, + "duration_ms": round(duration_ms, 1), + "description": json_scalar(tags.get("otel.status_description")), + } + ) + + hot_operations = [ + { + "operation": operation, + "count": len(durations), + "sum_ms": round(sum(durations), 1), + "max_ms": round(max(durations), 1), + } + for operation, durations in durations_by_operation.items() + ] + hot_operations.sort(key=lambda operation: float(operation["sum_ms"]), reverse=True) + return SourcegraphJaegerTraceSummary( + trace=trace_metadata, + jaeger_found=True, + span_count=len(spans), + hot_operations=tuple(cast(JSONDict, operation) for operation in hot_operations[:10]), + graphql_operations=tuple( + {"operation": operation, "count": count} + for operation, count in graphql_operations.most_common(10) + ), + errored_spans=tuple(errored_spans[:5]), + ) + + +def jaeger_span_tags(span: JSONDict) -> dict[str, object]: + """Return Jaeger span tags keyed by tag name.""" + tags: dict[str, object] = {} + for tag_value in json_list(span.get("tags")): + tag = json_dict(tag_value) + key = tag.get("key") + if isinstance(key, str): + tags[key] = tag.get("value") + return tags + + +def float_value(value: object) -> float: + """Return a JSON number as float, excluding booleans.""" + return float(value) if isinstance(value, int | float) and not isinstance(value, bool) else 0.0 + + +def json_scalar(value: object) -> JSONValue: + """Return `value` if it is a JSON scalar; otherwise return None.""" + if value is None or isinstance(value, bool | int | float | str): + return value + return None diff --git a/src/src_py_lib/utils/http.py b/src/src_py_lib/utils/http.py index 4b2ffe9..ba65348 100644 --- a/src/src_py_lib/utils/http.py +++ b/src/src_py_lib/utils/http.py @@ -54,6 +54,21 @@ def __init__( self.headers = {key.lower(): value for key, value in dict(headers or {}).items()} +@dataclass(frozen=True) +class HTTPResponse: + """HTTP response data returned by `HTTPClient`.""" + + status_code: int + reason_phrase: str + headers: dict[str, str] + content: bytes + http_version: str | None = None + + def header(self, name: str) -> str | None: + """Return one response header by case-insensitive name.""" + return self.headers.get(name.lower()) + + @dataclass class HTTPClient: """HTTPX-backed HTTP client for JSON APIs with pooled connections.""" @@ -103,6 +118,26 @@ def request( data: bytes | None = None, ) -> bytes: """Make an HTTP request and return raw response bytes.""" + return self.response( + method, + url, + headers=headers, + query=query, + json_body=json_body, + data=data, + ).content + + def response( + self, + method: str, + url: str, + *, + headers: Mapping[str, str] | None = None, + query: Mapping[str, str | int | float | bool | None] | None = None, + json_body: object | None = None, + data: bytes | None = None, + ) -> HTTPResponse: + """Make an HTTP request and return response headers plus raw bytes.""" request_url = _with_query(url, query) body = data request_headers = {"User-Agent": self.user_agent, **dict(headers or {})} @@ -152,7 +187,13 @@ def request( record_http_retry() self._sleep_before_retry(attempt, response.headers.get("Retry-After")) else: - return payload + return HTTPResponse( + status_code=response.status_code, + reason_phrase=response.reason_phrase, + headers=_response_headers(response.headers), + content=payload, + http_version=http_version, + ) except HTTPClientError: raise except httpx.TransportError as exception: @@ -179,9 +220,30 @@ def json( json_body: object | None = None, ) -> JSONDict: """Make an HTTP request and decode a JSON object response.""" - raw = self.request(method, url, headers=headers, query=query, json_body=json_body) + payload, _response = self.json_response( + method, + url, + headers=headers, + query=query, + json_body=json_body, + ) + return payload + + def json_response( + self, + method: str, + url: str, + *, + headers: Mapping[str, str] | None = None, + query: Mapping[str, str | int | float | bool | None] | None = None, + json_body: object | None = None, + ) -> tuple[JSONDict, HTTPResponse]: + """Make an HTTP request and return a JSON object plus response metadata.""" + response = self.response(method, url, headers=headers, query=query, json_body=json_body) try: - return json_dict(json.loads(raw.decode("utf-8")) if raw else {}) + return json_dict( + json.loads(response.content.decode("utf-8")) if response.content else {} + ), response except json.JSONDecodeError as exception: raise HTTPClientError( f"Invalid JSON response from {method} {_safe_url(url)}" @@ -240,6 +302,10 @@ def _header_items(headers: Mapping[str, str] | httpx.Headers) -> Iterable[tuple[ return headers.items() +def _response_headers(headers: httpx.Headers) -> dict[str, str]: + return {name.lower(): value for name, value in headers.items()} + + def _is_sensitive_header(name: str) -> bool: lowered = name.lower() return any(fragment in lowered for fragment in SENSITIVE_HEADER_FRAGMENTS) diff --git a/src/src_py_lib/utils/logging.py b/src/src_py_lib/utils/logging.py index 569f70c..7522fe0 100644 --- a/src/src_py_lib/utils/logging.py +++ b/src/src_py_lib/utils/logging.py @@ -40,7 +40,8 @@ SRC_LOG_VERBOSE: Final[str] = "SRC_LOG_VERBOSE" SRC_LOG_QUIET: Final[str] = "SRC_LOG_QUIET" SRC_LOG_SILENT: Final[str] = "SRC_LOG_SILENT" -TRACE_SPAN_BYTES: Final[int] = 4 +TRACE_ID_BYTES: Final[int] = 16 +SPAN_ID_BYTES: Final[int] = 8 MEBIBYTE: Final[int] = 1024 * 1024 SECRET_FIELD_FRAGMENTS: Final[tuple[str, ...]] = ( "api_key", @@ -188,13 +189,48 @@ def logging_settings_from_config( @dataclass(frozen=True) -class _SpanContext: - trace: str - span: str - parent_span: str | None = None +class TraceContext: + """W3C-compatible trace/span identifiers for logs and outbound requests.""" + trace_id: str + span_id: str + parent_span_id: str | None = None -_SPAN_CONTEXT: contextvars.ContextVar[_SpanContext | None] = contextvars.ContextVar( + def __post_init__(self) -> None: + if not _is_hex_identifier(self.trace_id, TRACE_ID_BYTES * 2): + raise ValueError("trace_id must be a non-zero 32-character hex string") + if not _is_hex_identifier(self.span_id, SPAN_ID_BYTES * 2): + raise ValueError("span_id must be a non-zero 16-character hex string") + if self.parent_span_id is not None and not _is_hex_identifier( + self.parent_span_id, SPAN_ID_BYTES * 2 + ): + raise ValueError("parent_span_id must be a non-zero 16-character hex string") + + @property + def trace(self) -> str: + """Return the log-field trace identifier.""" + return self.trace_id + + @property + def span(self) -> str: + """Return the log-field span identifier.""" + return self.span_id + + @property + def parent_span(self) -> str | None: + """Return the log-field parent span identifier.""" + return self.parent_span_id + + def child(self) -> TraceContext: + """Return a child span in the same trace.""" + return new_trace_context(self) + + def traceparent(self, *, sampled: bool = True) -> str: + """Return this context as a W3C traceparent header value.""" + return traceparent_header(self, sampled=sampled) + + +_SPAN_CONTEXT: contextvars.ContextVar[TraceContext | None] = contextvars.ContextVar( "src_py_lib_span_context", default=None ) @@ -597,6 +633,68 @@ def stage(name: str, **fields: Any) -> Generator[None]: yield +def current_trace_context() -> TraceContext | None: + """Return the current logging trace/span context, if one is active.""" + return _SPAN_CONTEXT.get() + + +def new_trace_context(parent: TraceContext | None = None) -> TraceContext: + """Return a root or child trace/span context. + + When `parent` is omitted, the current context is used as the parent when + available. Otherwise a new root trace is created. + """ + resolved_parent = parent if parent is not None else current_trace_context() + if resolved_parent is None: + return TraceContext( + trace_id=_nonzero_hex(TRACE_ID_BYTES), + span_id=_nonzero_hex(SPAN_ID_BYTES), + ) + return TraceContext( + trace_id=resolved_parent.trace_id, + span_id=_nonzero_hex(SPAN_ID_BYTES), + parent_span_id=resolved_parent.span_id, + ) + + +@contextlib.contextmanager +def trace_context(context: TraceContext | None = None) -> Generator[TraceContext]: + """Set a trace/span context for nested logs and outbound requests.""" + resolved_context = context or new_trace_context() + reset_token = _SPAN_CONTEXT.set(resolved_context) + try: + yield resolved_context + finally: + _SPAN_CONTEXT.reset(reset_token) + + +def traceparent_header(context: TraceContext | None = None, *, sampled: bool = True) -> str: + """Return a W3C traceparent header for `context` or the current context.""" + resolved_context = context or current_trace_context() or new_trace_context() + flags = "01" if sampled else "00" + return f"00-{resolved_context.trace_id}-{resolved_context.span_id}-{flags}" + + +def sampled_traceparent(context: TraceContext | None = None) -> str: + """Return a sampled W3C traceparent header value.""" + return traceparent_header(context, sampled=True) + + +def trace_context_from_traceparent(value: str | None) -> TraceContext | None: + """Return trace/span identifiers parsed from a W3C traceparent header.""" + if value is None: + return None + parts = value.split("-") + if len(parts) != 4 or parts[0] != "00": + return None + trace_id = parts[1].lower() + span_id = parts[2].lower() + try: + return TraceContext(trace_id=trace_id, span_id=span_id) + except ValueError: + return None + + @contextlib.contextmanager def event( key: str, @@ -608,12 +706,7 @@ def event( **fields: Any, ) -> Generator[dict[str, Any]]: """Emit start/end structured events around a block of work.""" - parent = _SPAN_CONTEXT.get() - span = _SpanContext( - trace=parent.trace if parent else secrets.token_hex(TRACE_SPAN_BYTES), - span=secrets.token_hex(TRACE_SPAN_BYTES), - parent_span=parent.span if parent else None, - ) + span = new_trace_context() reset_token = _SPAN_CONTEXT.set(span) try: log(start_level or level, key, logger_name=logger_name, phase="start", **fields) @@ -862,6 +955,22 @@ def _decode_http_bytes(value: object) -> str | None: return None +def _nonzero_hex(byte_count: int) -> str: + while True: + value = secrets.token_hex(byte_count) + if any(character != "0" for character in value): + return value + + +def _is_hex_identifier(value: str, length: int) -> bool: + lowered = value.lower() + return ( + len(lowered) == length + and any(character != "0" for character in lowered) + and all(character in "0123456789abcdef" for character in lowered) + ) + + def _secret_state(value: object) -> str: if value is None or value == "": return "missing" diff --git a/tests/test_logging_http_clients.py b/tests/test_logging_http_clients.py index 0566ca1..ef703cd 100644 --- a/tests/test_logging_http_clients.py +++ b/tests/test_logging_http_clients.py @@ -36,6 +36,9 @@ from src_py_lib.clients.sourcegraph import ( SourcegraphClient, SourcegraphClientConfig, + decode_external_service_id, + decode_repository_id, + encode_repository_id, normalize_sourcegraph_endpoint, sourcegraph_client_from_config, ) @@ -53,7 +56,7 @@ load_config_from_args, resolve_config_refs, ) -from src_py_lib.utils.http import HTTPClient, HTTPClientError +from src_py_lib.utils.http import HTTPClient, HTTPClientError, HTTPResponse from src_py_lib.utils.json_types import JSONDict, json_dict, json_list from src_py_lib.utils.logging import ( LoggingConfig, @@ -975,8 +978,8 @@ def test_event_context_adds_trace_and_span_fields(self) -> None: ) self.assertEqual(outer_start["trace"], outer_end["trace"]) self.assertEqual(outer_start["span"], outer_end["span"]) - self.assertEqual(len(outer_start["trace"]), 8) - self.assertEqual(len(outer_start["span"]), 8) + self.assertEqual(len(outer_start["trace"]), 32) + self.assertEqual(len(outer_start["span"]), 16) self.assertNotIn("parent_span", outer_start) self.assertEqual(inside["trace"], outer_start["trace"]) @@ -997,7 +1000,7 @@ def test_event_context_adds_trace_and_span_fields(self) -> None: ) self.assertEqual(inner_start["trace"], outer_start["trace"]) self.assertEqual(inner_start["span"], inner_end["span"]) - self.assertEqual(len(inner_start["span"]), 8) + self.assertEqual(len(inner_start["span"]), 16) self.assertEqual(inner_start["parent_span"], outer_start["span"]) self.assertNotEqual(inner_start["span"], outer_start["span"]) @@ -1019,6 +1022,21 @@ def test_event_context_adds_trace_and_span_fields(self) -> None: self.assertEqual(inner_log["span"], inner_start["span"]) self.assertEqual(inner_log["parent_span"], outer_start["span"]) + def test_trace_context_helpers_generate_w3c_traceparent_headers(self) -> None: + root = src.new_trace_context() + child = root.child() + + self.assertEqual(len(root.trace_id), 32) + self.assertEqual(len(root.span_id), 16) + self.assertEqual(child.trace_id, root.trace_id) + self.assertEqual(child.parent_span_id, root.span_id) + self.assertRegex(root.traceparent(), r"^00-[0-9a-f]{32}-[0-9a-f]{16}-01$") + self.assertEqual(src.trace_context_from_traceparent(root.traceparent()), root) + + with src.trace_context(root): + self.assertEqual(src.current_trace_context(), root) + self.assertEqual(src.traceparent_header(), root.traceparent()) + def test_event_can_lower_start_level_and_omit_success_status(self) -> None: with tempfile.TemporaryDirectory() as directory: log_file = Path(directory) / "events.json" @@ -1249,6 +1267,22 @@ def handler(request: httpx.Request) -> httpx.Response: self.assertEqual(json.loads(seen["body"]), {"hello": "world"}) self.assertEqual(client.max_connections, 7) + def test_json_response_returns_payload_and_response_metadata(self) -> None: + def handler(_request: httpx.Request) -> httpx.Response: + return httpx.Response( + 200, + json={"ok": True}, + headers={"X-Trace": "a" * 32}, + ) + + client = HTTPClient(max_attempts=1, transport=httpx.MockTransport(handler)) + payload, response = client.json_response("GET", "https://example.com/api") + + self.assertEqual(payload, {"ok": True}) + self.assertIsInstance(response, HTTPResponse) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.header("X-Trace"), "a" * 32) + def test_json_request_emits_structured_http_event(self) -> None: def handler(_request: httpx.Request) -> httpx.Response: return httpx.Response( @@ -1335,6 +1369,19 @@ def handler(_request: httpx.Request) -> httpx.Response: class ClientTest(unittest.TestCase): + def test_sourcegraph_node_ids_convert_between_graphql_and_database_ids(self) -> None: + self.assertEqual(42, decode_external_service_id("RXh0ZXJuYWxTZXJ2aWNlOjQy")) + self.assertEqual("UmVwb3NpdG9yeTo5OQ==", encode_repository_id(99)) + self.assertEqual(99, decode_repository_id("UmVwb3NpdG9yeTo5OQ==")) + self.assertEqual(99, src.decode_repository_id("UmVwb3NpdG9yeTo5OQ==")) + + with self.assertRaisesRegex(ValueError, "not a valid base64"): + decode_repository_id("not base64") + with self.assertRaisesRegex(ValueError, "not a Repository Node ID"): + decode_repository_id("RXh0ZXJuYWxTZXJ2aWNlOjQy") + with self.assertRaisesRegex(ValueError, "non-integer suffix"): + decode_external_service_id("RXh0ZXJuYWxTZXJ2aWNlOmFiYw==") + def test_normalize_sourcegraph_endpoint(self) -> None: self.assertEqual( normalize_sourcegraph_endpoint(" https://sourcegraph.example.com/ "), @@ -1412,6 +1459,82 @@ def test_sourcegraph_client_validate_queries_current_user(self) -> None: self.assertIn("SourcegraphClientValidate", str(body.get("query") or "")) self.assertIn("currentUser", str(body.get("query") or "")) + def test_sourcegraph_trace_mode_records_and_streams_jaeger_summary(self) -> None: + trace_id = "1" * 32 + span_id = "2" * 16 + requests: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + requests.append(request) + if request.url.path == "/.api/graphql": + return httpx.Response( + 200, + json={"data": {"currentUser": {"username": "alice"}}}, + headers={ + "X-Trace": trace_id, + "X-Trace-Span": span_id, + "X-Trace-URL": f"https://jaeger.example.com/trace/{trace_id}", + }, + ) + self.assertEqual(request.url.path, f"/-/debug/jaeger/api/traces/{trace_id}") + return httpx.Response( + 200, + json={ + "data": [ + { + "spans": [ + { + "operationName": "GraphQL request", + "duration": 120_000, + "tags": [{"key": "graphql.operationName", "value": "Viewer"}], + }, + { + "operationName": "repo lookup", + "duration": 30_000, + "tags": [ + {"key": "error", "value": True}, + {"key": "otel.status_description", "value": "boom"}, + ], + }, + ] + } + ] + }, + ) + + client = SourcegraphClient( + "https://sourcegraph.example.com/", + "token", + http=HTTPClient(max_attempts=1, transport=httpx.MockTransport(handler)), + trace=True, + ) + root_context = src.TraceContext(trace_id="3" * 32, span_id="4" * 16) + + with src.trace_context(root_context): + self.assertEqual( + client.graphql("query Viewer { currentUser { username } }"), + {"currentUser": {"username": "alice"}}, + ) + traces = client.drain_traces() + summaries = list(client.stream_jaeger_trace_summaries(traces, retry_delays_seconds=(0,))) + + self.assertEqual(len(requests), 2) + traceparent = requests[0].headers["traceparent"] + traceparent_parts = traceparent.split("-") + self.assertEqual(requests[0].headers["x-sourcegraph-request-trace"], "true") + self.assertRegex(traceparent, r"^00-[0-9a-f]{32}-[0-9a-f]{16}-01$") + self.assertEqual(traceparent_parts[1], root_context.trace_id) + self.assertEqual(traces[0].trace_id, trace_id) + self.assertEqual(traces[0].span_id, span_id) + self.assertEqual(traces[0].parent_trace_id, root_context.trace_id) + self.assertEqual(traces[0].parent_span_id, traceparent_parts[2]) + self.assertEqual(len(summaries), 1) + self.assertTrue(summaries[0].jaeger_found) + self.assertEqual(summaries[0].span_count, 2) + self.assertEqual(summaries[0].hot_operations[0]["operation"], "GraphQL request") + self.assertEqual(summaries[0].graphql_operations[0]["operation"], "Viewer") + self.assertEqual(summaries[0].errored_spans[0]["description"], "boom") + def test_graphql_client_paginates_cursor_results(self) -> None: http = RecordingHTTP( [