diff --git a/dev/TODO.md b/dev/TODO.md index 7fe3c6a..afe2acd 100644 --- a/dev/TODO.md +++ b/dev/TODO.md @@ -1,10 +1,5 @@ # TODO -## High priority: Bump src-py-lib after Node ID helper release - -- After releasing `src-py-lib` with Sourcegraph Node ID helpers, update - `pyproject.toml` and `uv.lock` to depend on that new version. - ## Medium priority: Lightweight incremental updates - When a new user's account is created, or a new repo is synced from a code host, @@ -69,6 +64,13 @@ If/when we revisit: 3. Add a CLI flag (e.g. `--cross-check-capture`) gated behind a clear "this doubles capture cost" warning. +## Low priority: Grouped full-set plan if memory is still too high + +Phase 1 now avoids per-repo username sets for non-overlapping full-set maps. +If memory remains too high after re-measuring, implement the Phase 2 grouped +plan in [mapping-efficiency.md](./mapping-efficiency.md): combine map-entry +overlays into final groups of repos that share the same desired username tuple. + ## Low priority: Expand group-membership filters beyond SAML `allowGroups`-style enforcement exists on more than just SAML, but only diff --git a/dev/analyze-memory.py b/dev/analyze-memory.py new file mode 100755 index 0000000..f3f062e --- /dev/null +++ b/dev/analyze-memory.py @@ -0,0 +1,618 @@ +#!/usr/bin/env python3 +"""Fit a Sourcegraph permissions memory model from e2e result JSON. + +The model is intentionally small and dependency-free: + + peak RSS MiB = intercept + users*b1 + repos*b2 + grants*b3 + +Use one command mode per fit. Mixing backup, no-backup, get, set, and restore +runs makes the per-grant coefficient much less useful. +""" + +from __future__ import annotations + +import argparse +import json +import math +import re +import statistics +import sys +from dataclasses import dataclass +from pathlib import Path +from typing import Any, cast + +FEATURE_NAMES = ("users", "repos", "grants") +COEFFICIENT_SCALE = { + "users": "bytes/user", + "repos": "bytes/repo", + "grants": "bytes/grant", +} + + +@dataclass(frozen=True) +class WorkloadDimensions: + """Canonical workload dimensions used by the memory model.""" + + users: float | None + repos: float | None + grants: float | None + + +@dataclass(frozen=True) +class MemoryObservation: + """One e2e command result with peak memory and workload dimensions.""" + + source_path: str + variant: str + case_name: str + command: str + iteration: int + peak_resident_megabytes: float + dimensions: WorkloadDimensions + + +@dataclass(frozen=True) +class MemoryModel: + """Fitted linear memory model.""" + + feature_names: tuple[str, ...] + coefficients_megabytes: dict[str, float] + observation_count: int + r_squared: float | None + mean_absolute_error_megabytes: float + p95_absolute_error_megabytes: float + max_absolute_error_megabytes: float + + +@dataclass(frozen=True) +class MemoryEstimate: + """Predicted memory for a proposed users x repos workload.""" + + dimensions: WorkloadDimensions + peak_resident_megabytes: float + peak_resident_megabytes_with_headroom: float + headroom_percent: float + + +def main() -> int: + parser = argparse.ArgumentParser( + description="Fit a fixed + users + repos + grants memory model from e2e JSON.", + ) + parser.add_argument( + "results_json", + nargs="+", + type=Path, + help="One or more JSON files written by dev/test-end-to-end.py --results-json.", + ) + parser.add_argument( + "--variant", + help="Only include one variant, e.g. candidate or baseline.", + ) + parser.add_argument( + "--command", + help="Only include one structured command, e.g. set_full or get.", + ) + parser.add_argument( + "--case-regex", + help="Only include cases whose e2e case name matches this regular expression.", + ) + parser.add_argument( + "--features", + default="users,repos,grants", + help="Comma-separated model features from users,repos,grants (default: all).", + ) + parser.add_argument( + "--min-grants", + type=float, + default=1.0, + help="Drop observations below this grant count (default: 1).", + ) + parser.add_argument( + "--estimate-users", + type=float, + help="Estimate memory for this many users.", + ) + parser.add_argument( + "--estimate-repos", + type=float, + help="Estimate memory for this many repos.", + ) + parser.add_argument( + "--estimate-grants", + type=float, + help="Estimate memory for this many grants; defaults to users * repos.", + ) + parser.add_argument( + "--headroom-percent", + type=float, + default=30.0, + help="Headroom to add to estimates (default: 30).", + ) + parser.add_argument( + "--json", + action="store_true", + help="Write machine-readable JSON instead of a text report.", + ) + arguments = parser.parse_args() + + feature_names = parse_feature_names(arguments.features) + observations = load_observations(arguments.results_json) + filtered_observations = filter_observations( + observations, + variant=arguments.variant, + command=arguments.command, + case_regex=arguments.case_regex, + min_grants=arguments.min_grants, + ) + model_observations = observations_with_features(filtered_observations, feature_names) + minimum_observations = len(feature_names) + 1 + if len(model_observations) < minimum_observations: + print( + "Need at least " + f"{minimum_observations} observations with {', '.join(feature_names)} " + f"to fit this model; found {len(model_observations)}.", + file=sys.stderr, + ) + return 2 + + try: + model = fit_memory_model(model_observations, feature_names) + except ValueError as error: + print(f"Could not fit memory model: {error}", file=sys.stderr) + print( + "Try filtering to one command mode, adding varied users x repos shapes, " + "or using fewer --features.", + file=sys.stderr, + ) + return 2 + + estimate = build_estimate( + model, + feature_names, + estimate_users=arguments.estimate_users, + estimate_repos=arguments.estimate_repos, + estimate_grants=arguments.estimate_grants, + headroom_percent=arguments.headroom_percent, + ) + if arguments.json: + write_json_report(model, model_observations, estimate) + else: + write_text_report(model, model_observations, estimate) + return 0 + + +def parse_feature_names(raw_features: str) -> tuple[str, ...]: + names = tuple(name.strip() for name in raw_features.split(",") if name.strip()) + invalid = sorted(set(names) - set(FEATURE_NAMES)) + if invalid: + raise SystemExit(f"Unknown feature(s): {', '.join(invalid)}") + duplicates = sorted({name for name in names if names.count(name) > 1}) + if duplicates: + raise SystemExit(f"Duplicate feature(s): {', '.join(duplicates)}") + if not names: + raise SystemExit("At least one feature is required.") + return names + + +def load_observations(paths: list[Path]) -> list[MemoryObservation]: + observations: list[MemoryObservation] = [] + for path in paths: + with path.open(encoding="utf-8") as input_file: + payload: object = json.load(input_file) + for result in result_mappings(payload): + observation = observation_from_result(path, result) + if observation is not None: + observations.append(observation) + return observations + + +def result_mappings(payload: object) -> list[dict[str, Any]]: + if isinstance(payload, dict): + mapping = cast(dict[str, Any], payload) + results = mapping.get("results") + if isinstance(results, list): + return mapping_items(cast(list[object], results)) + if "memory" in mapping and "workload" in mapping: + return [mapping] + if isinstance(payload, list): + return mapping_items(cast(list[object], payload)) + return [] + + +def mapping_items(values: list[object]) -> list[dict[str, Any]]: + """Return only dict-like JSON objects from a JSON list.""" + return [cast(dict[str, Any], value) for value in values if isinstance(value, dict)] + + +def observation_from_result(path: Path, result: dict[str, Any]) -> MemoryObservation | None: + memory = object_mapping(result.get("memory")) + workload = object_mapping(result.get("workload")) + if memory is None or workload is None: + return None + peak_resident_megabytes = first_number(memory, ("peak_rss_mb", "external_peak_rss_mb")) + if peak_resident_megabytes is None: + return None + return MemoryObservation( + source_path=str(path), + variant=string_value(result.get("variant")), + case_name=string_value(result.get("case")), + command=string_value(result.get("command")), + iteration=integer_value(result.get("iteration")), + peak_resident_megabytes=peak_resident_megabytes, + dimensions=WorkloadDimensions( + users=first_number( + workload, + ( + "memory_model_user_count", + "selected_user_count", + "captured_user_count", + "snapshot_user_count_max", + "user_count", + "total_users_scanned", + "sourcegraph_user_count", + "total_users", + ), + ), + repos=first_number( + workload, + ( + "memory_model_repo_count", + "planned_repo_count", + "restore_snapshot_repo_count", + "snapshot_repos_with_explicit_grants_max", + "repos_with_explicit_grants", + "loaded_repo_count", + "repo_count", + ), + ), + grants=first_number( + workload, + ( + "memory_model_grant_count", + "planned_total_grants", + "restore_snapshot_total_grants", + "selected_total_grants", + "snapshot_total_grants_max", + "total_grants", + "apply_payload_grant_count", + ), + ), + ), + ) + + +def filter_observations( + observations: list[MemoryObservation], + *, + variant: str | None, + command: str | None, + case_regex: str | None, + min_grants: float, +) -> list[MemoryObservation]: + pattern = re.compile(case_regex) if case_regex else None + filtered: list[MemoryObservation] = [] + for observation in observations: + if variant is not None and observation.variant != variant: + continue + if command is not None and observation.command != command: + continue + if pattern is not None and pattern.search(observation.case_name) is None: + continue + if observation.dimensions.grants is None or observation.dimensions.grants < min_grants: + continue + filtered.append(observation) + return filtered + + +def observations_with_features( + observations: list[MemoryObservation], feature_names: tuple[str, ...] +) -> list[MemoryObservation]: + return [ + observation + for observation in observations + if all(feature_value(observation.dimensions, name) is not None for name in feature_names) + ] + + +def fit_memory_model( + observations: list[MemoryObservation], feature_names: tuple[str, ...] +) -> MemoryModel: + feature_scales = feature_scale_by_name(observations, feature_names) + matrix = [ + [1.0] + + [ + required_feature_value(observation.dimensions, feature_name) + / feature_scales[feature_name] + for feature_name in feature_names + ] + for observation in observations + ] + targets = [observation.peak_resident_megabytes for observation in observations] + scaled_coefficients = solve_normal_equations(matrix, targets) + coefficients = {"intercept": scaled_coefficients[0]} + for feature_index, feature_name in enumerate(feature_names, start=1): + coefficients[feature_name] = ( + scaled_coefficients[feature_index] / feature_scales[feature_name] + ) + predictions = [ + predict_megabytes(coefficients, observation.dimensions) for observation in observations + ] + residuals = [ + target - prediction for target, prediction in zip(targets, predictions, strict=True) + ] + absolute_residuals = [abs(residual) for residual in residuals] + target_mean = statistics.fmean(targets) + residual_sum_squares = sum(residual * residual for residual in residuals) + total_sum_squares = sum((target - target_mean) ** 2 for target in targets) + return MemoryModel( + feature_names=feature_names, + coefficients_megabytes=coefficients, + observation_count=len(observations), + r_squared=( + None if total_sum_squares == 0 else 1.0 - residual_sum_squares / total_sum_squares + ), + mean_absolute_error_megabytes=statistics.fmean(absolute_residuals), + p95_absolute_error_megabytes=percentile(absolute_residuals, 95.0), + max_absolute_error_megabytes=max(absolute_residuals), + ) + + +def feature_scale_by_name( + observations: list[MemoryObservation], feature_names: tuple[str, ...] +) -> dict[str, float]: + scales: dict[str, float] = {} + for feature_name in feature_names: + maximum = max( + abs(required_feature_value(observation.dimensions, feature_name)) + for observation in observations + ) + scales[feature_name] = maximum if maximum > 0 else 1.0 + return scales + + +def solve_normal_equations(matrix: list[list[float]], targets: list[float]) -> list[float]: + column_count = len(matrix[0]) + normal_matrix = [[0.0 for _ in range(column_count)] for _ in range(column_count)] + normal_targets = [0.0 for _ in range(column_count)] + for row, target in zip(matrix, targets, strict=True): + for row_index in range(column_count): + normal_targets[row_index] += row[row_index] * target + for column_index in range(column_count): + normal_matrix[row_index][column_index] += row[row_index] * row[column_index] + return solve_linear_system(normal_matrix, normal_targets) + + +def solve_linear_system(matrix: list[list[float]], values: list[float]) -> list[float]: + size = len(values) + augmented = [matrix[row_index][:] + [values[row_index]] for row_index in range(size)] + for pivot_index in range(size): + pivot_row = max( + range(pivot_index, size), + key=lambda row_index: abs(augmented[row_index][pivot_index]), + ) + pivot_value = augmented[pivot_row][pivot_index] + if abs(pivot_value) < 1e-12: + raise ValueError("features are collinear or the sample is too small") + augmented[pivot_index], augmented[pivot_row] = augmented[pivot_row], augmented[pivot_index] + for column_index in range(pivot_index, size + 1): + augmented[pivot_index][column_index] /= pivot_value + for row_index in range(size): + if row_index == pivot_index: + continue + factor = augmented[row_index][pivot_index] + for column_index in range(pivot_index, size + 1): + augmented[row_index][column_index] -= factor * augmented[pivot_index][column_index] + return [augmented[row_index][size] for row_index in range(size)] + + +def build_estimate( + model: MemoryModel, + feature_names: tuple[str, ...], + *, + estimate_users: float | None, + estimate_repos: float | None, + estimate_grants: float | None, + headroom_percent: float, +) -> MemoryEstimate | None: + if estimate_users is None and estimate_repos is None and estimate_grants is None: + return None + if "users" in feature_names and estimate_users is None: + raise SystemExit("--estimate-users is required because users is in --features.") + if "repos" in feature_names and estimate_repos is None: + raise SystemExit("--estimate-repos is required because repos is in --features.") + if "grants" in feature_names and estimate_grants is None: + if estimate_users is None or estimate_repos is None: + raise SystemExit( + "--estimate-grants is required unless --estimate-users and --estimate-repos " + "are both set." + ) + estimate_grants = estimate_users * estimate_repos + dimensions = WorkloadDimensions( + users=estimate_users, + repos=estimate_repos, + grants=estimate_grants, + ) + peak_resident_megabytes = predict_megabytes(model.coefficients_megabytes, dimensions) + return MemoryEstimate( + dimensions=dimensions, + peak_resident_megabytes=peak_resident_megabytes, + peak_resident_megabytes_with_headroom=peak_resident_megabytes + * (1.0 + headroom_percent / 100.0), + headroom_percent=headroom_percent, + ) + + +def predict_megabytes( + coefficients_megabytes: dict[str, float], dimensions: WorkloadDimensions +) -> float: + prediction = coefficients_megabytes["intercept"] + for feature_name in FEATURE_NAMES: + coefficient = coefficients_megabytes.get(feature_name) + value = feature_value(dimensions, feature_name) + if coefficient is not None and value is not None: + prediction += coefficient * value + return prediction + + +def write_text_report( + model: MemoryModel, observations: list[MemoryObservation], estimate: MemoryEstimate | None +) -> None: + print(f"Observations used: {model.observation_count}") + print(f"Features: {', '.join(model.feature_names)}") + print("\nCoefficients:") + print(f" intercept: {model.coefficients_megabytes['intercept']:.3f} MiB") + for feature_name in model.feature_names: + coefficient_megabytes = model.coefficients_megabytes[feature_name] + coefficient_bytes = coefficient_megabytes * 1024.0 * 1024.0 + print( + f" {feature_name}: {coefficient_megabytes:.9f} MiB/unit " + f"({coefficient_bytes:.1f} {COEFFICIENT_SCALE[feature_name]})" + ) + r_squared = "n/a" if model.r_squared is None else f"{model.r_squared:.4f}" + print("\nFit quality:") + print(f" R²: {r_squared}") + print(f" mean absolute error: {model.mean_absolute_error_megabytes:.2f} MiB") + print(f" p95 absolute error: {model.p95_absolute_error_megabytes:.2f} MiB") + print(f" max absolute error: {model.max_absolute_error_megabytes:.2f} MiB") + print("\nObserved range:") + print_dimension_range(observations, "users") + print_dimension_range(observations, "repos") + print_dimension_range(observations, "grants") + if estimate is not None: + print("\nEstimate:") + print(f" users: {format_optional_number(estimate.dimensions.users)}") + print(f" repos: {format_optional_number(estimate.dimensions.repos)}") + print(f" grants: {format_optional_number(estimate.dimensions.grants)}") + print(f" peak RSS: {estimate.peak_resident_megabytes:.1f} MiB") + print( + f" with {estimate.headroom_percent:g}% headroom: " + f"{estimate.peak_resident_megabytes_with_headroom:.1f} MiB" + ) + + +def write_json_report( + model: MemoryModel, observations: list[MemoryObservation], estimate: MemoryEstimate | None +) -> None: + report: dict[str, Any] = { + "observation_count": model.observation_count, + "features": list(model.feature_names), + "coefficients_mib": model.coefficients_megabytes, + "coefficients_bytes": { + feature_name: model.coefficients_megabytes[feature_name] * 1024.0 * 1024.0 + for feature_name in model.feature_names + }, + "fit": { + "r_squared": model.r_squared, + "mean_absolute_error_mib": model.mean_absolute_error_megabytes, + "p95_absolute_error_mib": model.p95_absolute_error_megabytes, + "max_absolute_error_mib": model.max_absolute_error_megabytes, + }, + "observed_range": observed_range_to_json(observations), + "estimate": estimate_to_json(estimate), + } + json.dump(report, sys.stdout, indent=2, sort_keys=True) + sys.stdout.write("\n") + + +def print_dimension_range(observations: list[MemoryObservation], feature_name: str) -> None: + values = [ + value + for observation in observations + if (value := feature_value(observation.dimensions, feature_name)) is not None + ] + if not values: + print(f" {feature_name}: n/a") + return + print(f" {feature_name}: {format_number(min(values))} .. {format_number(max(values))}") + + +def observed_range_to_json(observations: list[MemoryObservation]) -> dict[str, dict[str, float]]: + ranges: dict[str, dict[str, float]] = {} + for feature_name in FEATURE_NAMES: + values = [ + value + for observation in observations + if (value := feature_value(observation.dimensions, feature_name)) is not None + ] + if values: + ranges[feature_name] = {"min": min(values), "max": max(values)} + return ranges + + +def estimate_to_json(estimate: MemoryEstimate | None) -> dict[str, Any] | None: + if estimate is None: + return None + return { + "users": estimate.dimensions.users, + "repos": estimate.dimensions.repos, + "grants": estimate.dimensions.grants, + "peak_rss_mib": estimate.peak_resident_megabytes, + "headroom_percent": estimate.headroom_percent, + "peak_rss_mib_with_headroom": estimate.peak_resident_megabytes_with_headroom, + } + + +def object_mapping(value: object) -> dict[str, Any] | None: + return cast(dict[str, Any], value) if isinstance(value, dict) else None + + +def first_number(mapping: dict[str, Any], names: tuple[str, ...]) -> float | None: + for name in names: + value = mapping.get(name) + if isinstance(value, bool): + continue + if isinstance(value, int | float): + return float(value) + if isinstance(value, str): + try: + return float(value) + except ValueError: + continue + return None + + +def string_value(value: object) -> str: + return value if isinstance(value, str) else "" + + +def integer_value(value: object) -> int: + if isinstance(value, bool): + return 0 + return value if isinstance(value, int) else 0 + + +def feature_value(dimensions: WorkloadDimensions, feature_name: str) -> float | None: + if feature_name == "users": + return dimensions.users + if feature_name == "repos": + return dimensions.repos + if feature_name == "grants": + return dimensions.grants + raise ValueError(f"Unknown feature: {feature_name}") + + +def required_feature_value(dimensions: WorkloadDimensions, feature_name: str) -> float: + value = feature_value(dimensions, feature_name) + if value is None: + raise ValueError(f"Observation is missing feature: {feature_name}") + return value + + +def percentile(values: list[float], percentile_value: float) -> float: + if not values: + return math.nan + sorted_values = sorted(values) + index = math.ceil((percentile_value / 100.0) * len(sorted_values)) - 1 + return sorted_values[min(max(index, 0), len(sorted_values) - 1)] + + +def format_optional_number(value: float | None) -> str: + return "n/a" if value is None else format_number(value) + + +def format_number(value: float) -> str: + return f"{value:.0f}" if value.is_integer() else f"{value:.3f}" + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/dev/mapping-efficiency.md b/dev/mapping-efficiency.md new file mode 100644 index 0000000..f8ccb51 --- /dev/null +++ b/dev/mapping-efficiency.md @@ -0,0 +1,140 @@ +# Mapping efficiency + +## Rectangular maps example + +Input maps + +```yaml +maps: + - name: engineers get generated repos + users: + usernames: + - alice + - bob + - carol + repos: + names: + - repo-1 + - repo-2 + - repo-3 +``` + +Current: Repo-centric plan + +repo-1 -> (alice, bob, carol) +repo-2 -> (alice, bob, carol) +repo-3 -> (alice, bob, carol) + +Grouped plan + +(alice, bob, carol) -> (repo-1, repo-2, repo-3) + +## Current semantics + +Each `maps:` entry is naturally a grouped rule: + +```text +selected users × selected repos +``` + +The full-set command must combine all entries before mutating Sourcegraph, +because `setRepositoryPermissionsForUsers` overwrites a repo's whole explicit +permission list. The required final state is: + +```text +desired_users(repo) = union(users_i for each map_i where repo is in repos_i) +``` + +Only after this union is known can the command safely apply per-repo overwrite +mutations. + +## Phase 1: lazy per-repo union sets + +The old full-set planner immediately expanded every map entry into: + +```text +repo_id -> set(username) +``` + +That is expensive for rectangular maps such as `10000 users × 1000 repos`: +the username strings are shared, but each repo owns a large Python set with one +hash-table entry per planned grant. + +Phase 1 keeps the existing downstream plan shape: + +```text +repo_id -> tuple(username) +``` + +but builds it more carefully: + +1. For a non-overlapping map entry, create one sorted username tuple and reuse + that same tuple for every matched repo. +2. If a later map entry touches a repo that already has users, promote only + that repo to a temporary set and union the usernames. +3. Convert only promoted repos back to sorted tuples after all map entries are + processed. + +This preserves the hard invariant while avoiding the large per-repo sets in +the common non-overlapping rectangular case. + +Measured on the sgdev test instance, the dry-run `10000x1000` case planned 10M +grants. Before Phase 1 it peaked at about 651 MiB RSS; after Phase 1 it peaked +at about 68 MiB RSS. + +## Phase 2: final grouped plan, if needed + +If Phase 1 is not enough, store the combined final plan as groups of repos that +share the same final user set: + +```text +tuple(username) -> tuple(repo_id) +``` + +This is not just one group per `maps:` entry. Map entries are input overlays; +final groups are the compressed result after every map entry has been unioned +onto the repo space. + +Example: + +```text +map A: alice,bob -> repo-1,repo-2 +map B: bob,chris -> repo-2,repo-3 + +final: +alice,bob -> repo-1 +alice,bob,chris -> repo-2 +bob,chris -> repo-3 +``` + +One practical data model would be: + +```python +@dataclass(frozen=True) +class RepositoryPermissionGroup: + usernames: tuple[str, ...] + repository_ids: tuple[str, ...] + + +@dataclass(frozen=True) +class FullSetPlan: + groups: tuple[RepositoryPermissionGroup, ...] + repo_names: dict[str, str] + repo_to_group_index: dict[str, int] + + def usernames_for_repo(self, repo_id: str) -> tuple[str, ...]: + return self.groups[self.repo_to_group_index[repo_id]].usernames +``` + +Apply still happens per repo: + +```text +for group in groups: + for repo_id in group.repository_ids: + setRepositoryPermissionsForUsers(repo_id, group.usernames) +``` + +Phase 2 touches more code than Phase 1: projected snapshots, diffs, +short-circuit filtering, apply iteration, and validation all currently expect +direct `repo_id -> usernames` lookups. Do it only if Phase 1 measurements still +show unacceptable memory use. diff --git a/dev/monitor-sourcegraph-load.sh b/dev/monitor-sourcegraph-load.sh new file mode 100755 index 0000000..584529c --- /dev/null +++ b/dev/monitor-sourcegraph-load.sh @@ -0,0 +1,348 @@ +#!/usr/bin/env bash +set -euo pipefail + +namespace="${SRC_AUTH_PERMS_SYNC_MONITOR_NAMESPACE:-m}" +interval_seconds="${SRC_AUTH_PERMS_SYNC_MONITOR_INTERVAL_SECONDS:-5}" +postgres_interval_seconds="${SRC_AUTH_PERMS_SYNC_MONITOR_POSTGRES_INTERVAL_SECONDS:-10}" +statements_interval_seconds="${SRC_AUTH_PERMS_SYNC_MONITOR_STATEMENTS_INTERVAL_SECONDS:-30}" +duration_seconds="${SRC_AUTH_PERMS_SYNC_MONITOR_DURATION_SECONDS:-}" +output_dir="${SRC_AUTH_PERMS_SYNC_MONITOR_OUTPUT_DIR:-}" +frontend_target="${SRC_AUTH_PERMS_SYNC_MONITOR_FRONTEND_TARGET:-deployment/sourcegraph-frontend}" +postgres_target="${SRC_AUTH_PERMS_SYNC_MONITOR_POSTGRES_TARGET:-pod/pgsql-0}" +kubectl_bin="${KUBECTL:-kubectl}" +psql_command="${SRC_AUTH_PERMS_SYNC_MONITOR_PSQL_COMMAND:-psql -X -U sg -d sg}" +stream_logs=true + +usage() { + cat <<'EOF' +Usage: dev/monitor-sourcegraph-load.sh [options] + +Collect timestamped Sourcegraph pod load evidence while the e2e script runs. +Press Ctrl-C to stop, or pass --duration-seconds. + +Options: + --namespace NAME Kubernetes namespace (default: m) + --interval-seconds N Pod/process/cgroup sample interval (default: 5) + --postgres-interval-seconds N pg_stat_activity sample interval (default: 10) + --statements-interval-seconds N pg_stat_statements sample interval (default: 30) + --duration-seconds N Stop automatically after N seconds + --output-dir PATH Output directory (default: /tmp/src-auth-perms-sync-sourcegraph-load-) + --frontend-target TARGET kubectl target for frontend (default: deployment/sourcegraph-frontend) + --postgres-target TARGET kubectl target for Postgres (default: pod/pgsql-0) + --psql-command COMMAND Command to run inside Postgres pod (default: psql -X -U sg -d sg) + --no-logs Do not stream frontend logs + -h, --help Show this help + +Examples: + dev/monitor-sourcegraph-load.sh + + dev/monitor-sourcegraph-load.sh \ + --duration-seconds 1800 \ + --output-dir /tmp/src-auth-perms-sync-load-$(date -u +%Y%m%d-%H%M%S) + +In another terminal, run: + uv run python dev/test-end-to-end.py --trace --sample-interval 0 --external-sample-interval 0 +EOF +} + +while [[ $# -gt 0 ]]; do + case "$1" in + --namespace) + namespace="$2" + shift 2 + ;; + --interval-seconds) + interval_seconds="$2" + shift 2 + ;; + --postgres-interval-seconds) + postgres_interval_seconds="$2" + shift 2 + ;; + --statements-interval-seconds) + statements_interval_seconds="$2" + shift 2 + ;; + --duration-seconds) + duration_seconds="$2" + shift 2 + ;; + --output-dir) + output_dir="$2" + shift 2 + ;; + --frontend-target) + frontend_target="$2" + shift 2 + ;; + --postgres-target) + postgres_target="$2" + shift 2 + ;; + --psql-command) + psql_command="$2" + shift 2 + ;; + --no-logs) + stream_logs=false + shift + ;; + -h|--help) + usage + exit 0 + ;; + *) + echo "Unknown argument: $1" >&2 + usage >&2 + exit 2 + ;; + esac +done + +if [[ -z "${output_dir}" ]]; then + output_dir="/tmp/src-auth-perms-sync-sourcegraph-load-$(date -u +%Y%m%d-%H%M%S)" +fi +mkdir -p "${output_dir}" + +end_epoch="" +if [[ -n "${duration_seconds}" ]]; then + end_epoch="$(( $(date +%s) + duration_seconds ))" +fi + +pids=() + +timestamp() { + date -u +%Y-%m-%dT%H:%M:%SZ +} + +should_continue() { + [[ -z "${end_epoch}" || "$(date +%s)" -lt "${end_epoch}" ]] +} + +append_header() { + local file="$1" + local title="$2" + { + printf '\n===== %s %s =====\n' "$(timestamp)" "${title}" + } >>"${file}" +} + +run_sample_loop() { + local name="$1" + local sleep_seconds="$2" + local pid + shift 2 + ( + while should_continue; do + "$@" || true + sleep "${sleep_seconds}" + done + ) & + pid="$!" + pids+=("${pid}") + echo "Started ${name} sampler: pid=${pid} interval=${sleep_seconds}s" +} + +run_stream() { + local name="$1" + local pid + shift + ( + "$@" || true + ) & + pid="$!" + pids+=("${pid}") + echo "Started ${name} stream: pid=${pid}" +} + +cleanup() { + local status=$? + trap - EXIT INT TERM + if [[ ${#pids[@]} -gt 0 ]]; then + kill "${pids[@]}" 2>/dev/null || true + wait "${pids[@]}" 2>/dev/null || true + fi + echo "Stopped Sourcegraph load monitor. Output: ${output_dir}" + exit "${status}" +} + +trap cleanup EXIT INT TERM + +kubectl_exec() { + local target="$1" + shift + "${kubectl_bin}" exec -n "${namespace}" "${target}" -- "$@" +} + +kubectl_exec_stdin() { + local target="$1" + shift + "${kubectl_bin}" exec -i -n "${namespace}" "${target}" -- "$@" +} + +prepare_pg_stat_statements() { + local file="${output_dir}/postgres-statements-setup.log" + append_header "${file}" "create pg_stat_statements extension and reset stats" + cat <<'SQL' | kubectl_exec_stdin "${postgres_target}" sh -lc "${psql_command} -P pager=off" >>"${file}" 2>&1 || true +select current_database(), current_user; +show shared_preload_libraries; +show track_io_timing; +create extension if not exists pg_stat_statements; +select pg_stat_statements_reset(); +SQL +} + +sample_kubectl_top() { + local file="${output_dir}/kubectl-top-pods-containers.log" + append_header "${file}" "kubectl top pods --containers" + "${kubectl_bin}" top pods -n "${namespace}" --containers >>"${file}" 2>&1 || true +} + +sample_frontend_processes() { + local file="${output_dir}/frontend-processes.log" + append_header "${file}" "${frontend_target} process CPU/RSS" + kubectl_exec "${frontend_target}" sh -lc ' + echo "--- top CPU ---" + ps auxww | sort -nrk3 | head -30 + echo "--- top RSS ---" + ps auxww | sort -nrk4 | head -30 + ' >>"${file}" 2>&1 || true +} + +sample_postgres_processes() { + local file="${output_dir}/postgres-processes.log" + append_header "${file}" "${postgres_target} process CPU/RSS" + kubectl_exec "${postgres_target}" sh -lc ' + echo "--- top CPU ---" + ps auxww | sort -nrk3 | head -30 + echo "--- top RSS ---" + ps auxww | sort -nrk4 | head -30 + ' >>"${file}" 2>&1 || true +} + +sample_cgroups() { + local file="${output_dir}/cgroups.log" + append_header "${file}" "cgroup CPU/memory" + for target in "${frontend_target}" "${postgres_target}"; do + { + echo "--- ${target} ---" + kubectl_exec "${target}" sh -lc ' + echo "cpu.stat" + cat /sys/fs/cgroup/cpu.stat 2>/dev/null || cat /sys/fs/cgroup/cpu/cpu.stat 2>/dev/null || true + echo "memory.current" + cat /sys/fs/cgroup/memory.current 2>/dev/null || cat /sys/fs/cgroup/memory/memory.usage_in_bytes 2>/dev/null || true + echo "memory.events" + cat /sys/fs/cgroup/memory.events 2>/dev/null || true + echo "memory.max" + cat /sys/fs/cgroup/memory.max 2>/dev/null || cat /sys/fs/cgroup/memory/memory.limit_in_bytes 2>/dev/null || true + ' + } >>"${file}" 2>&1 || true + done +} + +sample_postgres_activity() { + local file="${output_dir}/postgres-activity.log" + append_header "${file}" "pg_stat_activity, waits, locks" + cat <<'SQL' | kubectl_exec_stdin "${postgres_target}" sh -lc "${psql_command} -P pager=off" >>"${file}" 2>&1 || true +select + pid, + now() - query_start as age, + state, + wait_event_type, + wait_event, + left(query, 220) as query +from pg_stat_activity +where state <> 'idle' +order by age desc +limit 30; + +select + wait_event_type, + wait_event, + state, + count(*) +from pg_stat_activity +group by 1,2,3 +order by count(*) desc; + +select + locktype, + mode, + granted, + count(*) +from pg_locks +group by 1,2,3 +order by count(*) desc; +SQL +} + +sample_pg_stat_statements() { + local file="${output_dir}/postgres-statements.log" + append_header "${file}" "pg_stat_statements top total_exec_time" + cat <<'SQL' | kubectl_exec_stdin "${postgres_target}" sh -lc "${psql_command} -P pager=off" >>"${file}" 2>&1 || true +select + calls, + round(total_exec_time::numeric, 1) as total_ms, + round(mean_exec_time::numeric, 1) as mean_ms, + rows, + left(query, 260) as query +from pg_stat_statements +order by total_exec_time desc +limit 25; +SQL +} + +snapshot_pod_descriptions() { + local file="${output_dir}/pod-descriptions.log" + append_header "${file}" "kubectl describe selected targets" + "${kubectl_bin}" describe -n "${namespace}" "${frontend_target}" >>"${file}" 2>&1 || true + "${kubectl_bin}" describe -n "${namespace}" "${postgres_target}" >>"${file}" 2>&1 || true +} + +stream_frontend_logs() { + "${kubectl_bin}" logs -n "${namespace}" "${frontend_target}" --since=1m --timestamps -f \ + >"${output_dir}/frontend.log" 2>"${output_dir}/frontend-log-errors.log" +} + +stream_frontend_error_logs() { + "${kubectl_bin}" logs -n "${namespace}" "${frontend_target}" --since=1m --timestamps -f 2>/dev/null \ + | grep -Ei 'timeout|deadline|database|postgres|graphql|error|slow|cancel' \ + >"${output_dir}/frontend-errors-filtered.log" || true +} + +cat >"${output_dir}/metadata.txt" < int: + return self.users * self.repos + + @property + def name(self) -> str: + return f"u{self.users:05d}-r{self.repos:05d}-g{self.grants:010d}" + + +@dataclass(frozen=True) +class ExternalServiceChoice: + """Code host connection selected for repo sampling.""" + + graphql_id: str + database_id: int + display_name: str + kind: str + url: str + repo_count: int + + +@dataclass(frozen=True) +class GeneratedMap: + """One generated maps.yaml file and its workload dimensions.""" + + case: SweepCase + path: Path + + +@dataclass(frozen=True) +class CommandRunResult: + """One CLI execution result written in analyze-memory.py-compatible shape.""" + + generated_map: GeneratedMap + return_code: int + elapsed_seconds: float + output_path: Path + log_path: Path | None + run_record: dict[str, Any] | None + + +def main() -> int: + parser = build_parser() + arguments = parser.parse_args() + mode = cast(RunMode, arguments.mode) + if mode == "apply-no-backup" and not arguments.allow_apply: + parser.error("--mode apply-no-backup requires --allow-apply") + + config = sourcegraph_config(arguments) + output_dir = arguments.output_dir or default_output_dir(config.src_endpoint) + maps_dir = output_dir / "maps" + output_dir.mkdir(parents=True, exist_ok=True) + maps_dir.mkdir(parents=True, exist_ok=True) + + requested_cases = parse_cases(arguments.cases) + + client = src.SourcegraphClient( + endpoint=config.src_endpoint, + token=config.src_access_token, + http=src.HTTPClient( + timeout=arguments.http_timeout_seconds, + max_connections=max(4, arguments.parallelism), + ), + ) + try: + external_services = list_external_services(client) + inventory_repo_count = sum(service.repo_count for service in external_services) + service = choose_external_service(external_services, arguments.external_service_id) + total_user_count = count_users(client) + cases = requested_cases or default_cases_for_inventory( + total_user_count, + service.repo_count, + ) + max_users = max(sweep_case.users for sweep_case in cases) + max_repos = max(sweep_case.repos for sweep_case in cases) + usernames = list_usernames(client, max_users, arguments.page_size) + repo_names = list_repo_names(client, service, max_repos, arguments.page_size) + finally: + client.http.close() + + generated_maps = write_maps(maps_dir, cases, usernames, repo_names, service) + write_manifest(output_dir, generated_maps, service, config.src_endpoint, inventory_repo_count) + print(f"Generated {len(generated_maps)} maps.yaml file(s) under {maps_dir}") + print( + f"Selected code host: {service.display_name} id={service.database_id} " + f"repos={service.repo_count}; instance repoCount sum={inventory_repo_count}" + ) + + if not arguments.run: + print("Generation only. Re-run with --run to execute the sweep.") + return 0 + + run_results = run_sweep( + generated_maps, + endpoint=config.src_endpoint, + access_token=config.src_access_token, + output_dir=output_dir, + command=arguments.command, + mode=mode, + parallelism=arguments.parallelism, + explicit_permissions_batch_size=arguments.explicit_permissions_batch_size, + http_timeout_seconds=arguments.http_timeout_seconds, + sample_interval=arguments.sample_interval, + trace=arguments.trace, + sourcegraph_user_count=total_user_count, + sourcegraph_inventory_repo_count=inventory_repo_count, + ) + write_results(output_dir, run_results, inventory_repo_count, total_user_count) + return 0 if all(result.return_code == 0 for result in run_results) else 1 + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description="Generate and optionally run maps.yaml memory-model sweep cases.", + ) + parser.add_argument( + "--env-file", + type=Path, + default=Path(".env"), + help="Environment file with SRC_ENDPOINT and SRC_ACCESS_TOKEN (default: .env).", + ) + parser.add_argument("--src-endpoint", help="Override SRC_ENDPOINT for discovery and runs.") + parser.add_argument("--src-access-token", help="Override SRC_ACCESS_TOKEN.") + parser.add_argument( + "--output-dir", + type=Path, + help=( + "Directory for generated maps and result files. " + "Defaults under src-auth-perms-sync-runs/." + ), + ) + parser.add_argument( + "--cases", + default=DEFAULT_CASES, + help=( + "Comma-separated users x repos cases, e.g. '100x10,1000x25', " + "or 'auto' for a gentle inventory-aware sweep. Default: auto." + ), + ) + parser.add_argument( + "--external-service-id", + type=int, + help="Decoded external service DB id to sample repos from. Defaults to largest repoCount.", + ) + parser.add_argument( + "--page-size", + type=int, + default=1000, + help="GraphQL page size for discovery queries (default: 1000).", + ) + parser.add_argument( + "--run", + action="store_true", + help="Run src-auth-perms-sync for each generated maps.yaml file.", + ) + parser.add_argument( + "--mode", + choices=("dry-run", "apply-no-backup"), + default="dry-run", + help="Run mode when --run is set. Default is dry-run.", + ) + parser.add_argument( + "--allow-apply", + action="store_true", + help="Required safety acknowledgement for --mode apply-no-backup.", + ) + parser.add_argument( + "--command", + default=DEFAULT_COMMAND, + help=f"Command used to invoke the CLI (default: {DEFAULT_COMMAND!r}).", + ) + parser.add_argument( + "--parallelism", + type=int, + default=1, + help="CLI --parallelism for sweep runs. Default 1 is gentle on pgsql.", + ) + parser.add_argument( + "--explicit-permissions-batch-size", + type=int, + default=25, + help="CLI --explicit-permissions-batch-size for sweep runs (default: 25).", + ) + parser.add_argument( + "--http-timeout-seconds", + type=float, + default=120.0, + help="HTTP timeout for discovery and CLI runs (default: 120).", + ) + parser.add_argument( + "--sample-interval", + type=float, + default=1.0, + help="CLI --sample-interval for resource samples (default: 1).", + ) + parser.add_argument( + "--trace", + action="store_true", + help="Pass --trace to src-auth-perms-sync sweep runs.", + ) + return parser + + +def sourcegraph_config(arguments: argparse.Namespace) -> SweepSourcegraphConfig: + overrides: dict[str, object] = {} + if arguments.src_endpoint: + overrides["src_endpoint"] = arguments.src_endpoint + if arguments.src_access_token: + overrides["src_access_token"] = arguments.src_access_token + return load_config( + SweepSourcegraphConfig, + env_file=arguments.env_file, + cli_overrides=overrides, + base_dir=Path.cwd(), + resolve_op_refs=True, + require=("src_access_token",), + ) + + +def parse_cases(raw_cases: str) -> list[SweepCase] | None: + if raw_cases.strip().lower() == "auto": + return None + cases: list[SweepCase] = [] + for raw_case in raw_cases.split(","): + case = raw_case.strip().lower() + if not case: + continue + users_text, separator, repos_text = case.partition("x") + if not separator: + raise SystemExit(f"Invalid case {raw_case!r}; expected USERSxREPOS") + try: + users = int(users_text) + repos = int(repos_text) + except ValueError as error: + raise SystemExit(f"Invalid case {raw_case!r}; counts must be integers") from error + if users < 1 or repos < 1: + raise SystemExit(f"Invalid case {raw_case!r}; counts must be >= 1") + cases.append(SweepCase(users=users, repos=repos)) + if not cases: + raise SystemExit("At least one --cases entry is required") + return cases + + +def default_cases_for_inventory(user_count: int, repo_count: int) -> list[SweepCase]: + """Return a safe default sweep that covers user, repo, and grant axes.""" + if user_count < 1: + raise SystemExit("Need at least one Sourcegraph user for an auto sweep") + if repo_count < 1: + raise SystemExit("Need at least one Sourcegraph repo for an auto sweep") + + user_points = bounded_points(user_count, DEFAULT_USER_POINTS) + repo_points = bounded_points(repo_count, DEFAULT_REPO_POINTS) + cases: list[SweepCase] = [SweepCase(users=users, repos=1) for users in user_points] + cases.extend(SweepCase(users=1, repos=repos) for repos in repo_points if repos != 1) + + for users, repos in ( + (1000, 10), + (10000, 10), + (1000, 100), + (100, 1000), + ): + if users <= user_count and repos <= repo_count: + cases.append(SweepCase(users=users, repos=repos)) + + return unique_cases(cases) + + +def bounded_points(available_count: int, candidate_points: Sequence[int]) -> list[int]: + """Return candidate points that fit, plus the exact inventory cap if useful.""" + points = [point for point in candidate_points if point <= available_count] + if available_count not in points and available_count < candidate_points[-1]: + points.append(available_count) + return sorted(set(points)) + + +def unique_cases(cases: Sequence[SweepCase]) -> list[SweepCase]: + """Preserve case order while removing duplicates.""" + seen: set[tuple[int, int]] = set() + unique: list[SweepCase] = [] + for sweep_case in cases: + key = (sweep_case.users, sweep_case.repos) + if key in seen: + continue + seen.add(key) + unique.append(sweep_case) + return unique + + +def list_external_services(client: src.SourcegraphClient) -> list[ExternalServiceChoice]: + services: list[ExternalServiceChoice] = [] + for node in client.stream_connection_nodes( + QUERY_EXTERNAL_SERVICES, + variables={"first": 100, "after": None}, + connection_path=("externalServices",), + page_size=100, + ): + service = cast(dict[str, Any], node) + graphql_id = str(service["id"]) + services.append( + ExternalServiceChoice( + graphql_id=graphql_id, + database_id=src.decode_external_service_id(graphql_id), + display_name=str(service.get("displayName") or ""), + kind=str(service.get("kind") or ""), + url=str(service.get("url") or ""), + repo_count=int(service.get("repoCount") or 0), + ) + ) + if not services: + raise SystemExit("No external services found on the Sourcegraph instance") + return services + + +def choose_external_service( + services: list[ExternalServiceChoice], requested_id: int | None +) -> ExternalServiceChoice: + if requested_id is not None: + for service in services: + if service.database_id == requested_id: + return service + raise SystemExit(f"External service id {requested_id} was not found") + return max(services, key=lambda service: service.repo_count) + + +def list_usernames(client: src.SourcegraphClient, count: int, page_size: int) -> list[str]: + usernames: list[str] = [] + for node in client.stream_connection_nodes( + QUERY_USERNAMES, + connection_path=("users",), + page_size=page_size, + ): + username = node.get("username") + if isinstance(username, str) and username: + usernames.append(username) + if len(usernames) >= count: + break + if len(usernames) < count: + raise SystemExit(f"Need {count} users but discovered only {len(usernames)}") + return usernames + + +def count_users(client: src.SourcegraphClient) -> int: + """Return total users on the Sourcegraph instance.""" + data = client.graphql(QUERY_USER_COUNT) + users = cast(dict[str, Any], data.get("users") or {}) + total_count = users.get("totalCount") + if not isinstance(total_count, int): + raise SystemExit("CountUsers response did not include users.totalCount") + return total_count + + +def list_repo_names( + client: src.SourcegraphClient, + service: ExternalServiceChoice, + count: int, + page_size: int, +) -> list[str]: + repo_names: list[str] = [] + for node in client.stream_connection_nodes( + QUERY_REPOS_BY_EXTERNAL_SERVICE, + variables={"externalService": service.graphql_id}, + connection_path=("repositories",), + page_size=page_size, + ): + name = node.get("name") + if isinstance(name, str) and name: + repo_names.append(name) + if len(repo_names) >= count: + break + if len(repo_names) < count: + raise SystemExit( + f"Need {count} repos from external service id={service.database_id} " + f"but discovered only {len(repo_names)}" + ) + return repo_names + + +def write_maps( + maps_dir: Path, + cases: Sequence[SweepCase], + usernames: Sequence[str], + repo_names: Sequence[str], + service: ExternalServiceChoice, +) -> list[GeneratedMap]: + generated: list[GeneratedMap] = [] + for sweep_case in cases: + map_path = maps_dir / f"maps-{sweep_case.name}.yaml" + payload = { + "maps": [ + { + "name": ( + "memory model " + f"users={sweep_case.users} repos={sweep_case.repos} " + f"grants={sweep_case.grants}" + ), + "users": {"usernames": list(usernames[: sweep_case.users])}, + "repos": { + "codeHostConnection": {"id": service.database_id}, + "names": list(repo_names[: sweep_case.repos]), + }, + } + ] + } + with map_path.open("w", encoding="utf-8") as output_file: + output_file.write( + "# Generated by dev/run-memory-model-sweep.py; safe to delete/regenerate.\n" + ) + output_file.write( + f"# users={sweep_case.users} repos={sweep_case.repos} " + f"planned_grants={sweep_case.grants}\n" + ) + yaml.safe_dump(payload, output_file, sort_keys=False, allow_unicode=True) + generated.append(GeneratedMap(case=sweep_case, path=map_path)) + return generated + + +def write_manifest( + output_dir: Path, + generated_maps: Sequence[GeneratedMap], + service: ExternalServiceChoice, + endpoint: str, + inventory_repo_count: int, +) -> None: + manifest = { + "generated_at": datetime.datetime.now(datetime.UTC).isoformat(timespec="seconds"), + "endpoint": endpoint, + "external_service": service_to_json(service), + "sourcegraph_inventory_repo_count": inventory_repo_count, + "maps": [ + { + "case": generated_map.case.name, + "users": generated_map.case.users, + "repos": generated_map.case.repos, + "grants": generated_map.case.grants, + "path": str(generated_map.path), + } + for generated_map in generated_maps + ], + } + write_json(output_dir / "manifest.json", manifest) + + +def run_sweep( + generated_maps: Sequence[GeneratedMap], + *, + endpoint: str, + access_token: str, + output_dir: Path, + command: str, + mode: RunMode, + parallelism: int, + explicit_permissions_batch_size: int, + http_timeout_seconds: float, + sample_interval: float, + trace: bool, + sourcegraph_user_count: int, + sourcegraph_inventory_repo_count: int, +) -> list[CommandRunResult]: + results: list[CommandRunResult] = [] + for generated_map in generated_maps: + print(f"Running {generated_map.case.name} ...", flush=True) + started = time.monotonic() + process_output_path = output_dir / f"{generated_map.case.name}.out" + arguments = command_arguments( + command, + generated_map.path, + mode=mode, + parallelism=parallelism, + explicit_permissions_batch_size=explicit_permissions_batch_size, + http_timeout_seconds=http_timeout_seconds, + sample_interval=sample_interval, + trace=trace, + ) + environment = command_environment(endpoint, access_token) + process = subprocess.run( + arguments, + cwd=Path.cwd(), + env=environment, + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + check=False, + ) + elapsed_seconds = time.monotonic() - started + process_output_path.write_text(process.stdout, encoding="utf-8") + log_path = log_path_from_output(process.stdout) + run_record = read_run_record(log_path) + result = CommandRunResult( + generated_map=generated_map, + return_code=process.returncode, + elapsed_seconds=elapsed_seconds, + output_path=process_output_path, + log_path=log_path, + run_record=run_record, + ) + results.append(result) + write_results( + output_dir, + results, + inventory_repo_count=sourcegraph_inventory_repo_count, + sourcegraph_user_count=sourcegraph_user_count, + ) + print( + f" return_code={process.returncode} " + f"peak_rss_mb={memory_peak(result.run_record)} " + f"output={process_output_path}", + flush=True, + ) + if process.returncode != 0: + print("Stopping after first failed case.", file=sys.stderr) + break + return results + + +def command_arguments( + command: str, + map_path: Path, + *, + mode: RunMode, + parallelism: int, + explicit_permissions_batch_size: int, + http_timeout_seconds: float, + sample_interval: float, + trace: bool, +) -> list[str]: + arguments = [ + *shlex.split(command), + "--set", + str(map_path.resolve()), + "--full", + "--parallelism", + str(parallelism), + "--explicit-permissions-batch-size", + str(explicit_permissions_batch_size), + "--http-timeout-seconds", + f"{http_timeout_seconds:g}", + "--sample-interval", + f"{sample_interval:g}", + ] + if mode == "apply-no-backup": + arguments.extend(("--apply", "--no-backup")) + if trace: + arguments.append("--trace") + return arguments + + +def command_environment(endpoint: str, access_token: str) -> dict[str, str]: + environment = dict(os.environ) + environment["SRC_ENDPOINT"] = endpoint + environment["SRC_ACCESS_TOKEN"] = access_token + return environment + + +def log_path_from_output(output: str) -> Path | None: + match = LOG_PATH_PATTERN.search(output) + return Path(match.group(1)) if match else None + + +def read_run_record(log_path: Path | None) -> dict[str, Any] | None: + if log_path is None or not log_path.exists(): + return None + run_record: dict[str, Any] | None = None + with log_path.open(encoding="utf-8") as input_file: + for line in input_file: + try: + record = json.loads(line) + except json.JSONDecodeError: + continue + if not isinstance(record, dict): + continue + record_mapping = cast(dict[str, object], record) + if record_mapping.get("event") == "run" and record_mapping.get("phase") == "end": + run_record = cast(dict[str, Any], record_mapping) + return run_record + + +def write_results( + output_dir: Path, + results: Sequence[CommandRunResult], + inventory_repo_count: int, + sourcegraph_user_count: int, +) -> None: + result_payload = { + "generated_at": datetime.datetime.now(datetime.UTC).isoformat(timespec="seconds"), + "results": [ + result_to_json(result, inventory_repo_count, sourcegraph_user_count) + for result in results + ], + "comparisons": [], + } + write_json(output_dir / "results.json", result_payload) + write_results_csv( + output_dir / "results.csv", + results, + inventory_repo_count, + sourcegraph_user_count, + ) + + +def result_to_json( + result: CommandRunResult, inventory_repo_count: int, sourcegraph_user_count: int +) -> dict[str, Any]: + run_record = result.run_record or {} + peak_rss_mb = memory_peak(result.run_record) + case = result.generated_map.case + return { + "variant": "candidate", + "iteration": 1, + "case": case.name, + "arguments": ["--set", str(result.generated_map.path), "--full"], + "return_code": result.return_code, + "elapsed_seconds": round(result.elapsed_seconds, 3), + "log_path": str(result.log_path) if result.log_path else None, + "run_directory": str(result.log_path.parent) if result.log_path else None, + "command": run_record.get("command") or "set_full", + "status": run_record.get("status"), + "jaeger_traces": [], + "memory": { + "peak_rss_mb": peak_rss_mb, + "sampled_peak_rss_mb": None, + "external_peak_rss_mb": None, + "resource_sample_count": 0, + "external_sample_count": 0, + "max_num_fds": run_record.get("num_fds"), + "max_num_threads": run_record.get("num_threads"), + "max_process_cpu_percent": None, + }, + "phase_memory": [], + "artifact_sizes": {}, + "workload": workload_json(case, inventory_repo_count, sourcegraph_user_count), + } + + +def workload_json( + sweep_case: SweepCase, inventory_repo_count: int, sourcegraph_user_count: int +) -> dict[str, int]: + return { + "selected_user_count": sweep_case.users, + "selected_repo_count": sweep_case.repos, + "selected_total_grants": sweep_case.grants, + "memory_model_user_count": sweep_case.users, + "memory_model_repo_count": sweep_case.repos, + "memory_model_grant_count": sweep_case.grants, + "sourcegraph_user_count": sourcegraph_user_count, + "sourcegraph_inventory_repo_count": inventory_repo_count, + } + + +def write_results_csv( + path: Path, + results: Sequence[CommandRunResult], + inventory_repo_count: int, + sourcegraph_user_count: int, +) -> None: + fieldnames = [ + "case", + "users", + "repos", + "grants", + "sourcegraph_users_discovered", + "sourcegraph_inventory_repo_count", + "return_code", + "elapsed_seconds", + "peak_rss_mb", + "log_path", + "map_path", + "output_path", + ] + with path.open("w", encoding="utf-8", newline="") as output_file: + writer = csv.DictWriter(output_file, fieldnames=fieldnames) + writer.writeheader() + for result in results: + case = result.generated_map.case + writer.writerow( + { + "case": case.name, + "users": case.users, + "repos": case.repos, + "grants": case.grants, + "sourcegraph_users_discovered": sourcegraph_user_count, + "sourcegraph_inventory_repo_count": inventory_repo_count, + "return_code": result.return_code, + "elapsed_seconds": f"{result.elapsed_seconds:.3f}", + "peak_rss_mb": memory_peak(result.run_record) or "", + "log_path": str(result.log_path) if result.log_path else "", + "map_path": str(result.generated_map.path), + "output_path": str(result.output_path), + } + ) + + +def memory_peak(run_record: Mapping[str, Any] | None) -> float | None: + if run_record is None: + return None + value = run_record.get("peak_rss_mb") + return float(value) if isinstance(value, int | float) else None + + +def write_json(path: Path, payload: object) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w", encoding="utf-8") as output_file: + json.dump(payload, output_file, indent=2, sort_keys=True) + output_file.write("\n") + + +def service_to_json(service: ExternalServiceChoice) -> dict[str, object]: + return { + "graphql_id": service.graphql_id, + "database_id": service.database_id, + "display_name": service.display_name, + "kind": service.kind, + "url": service.url, + "repo_count": service.repo_count, + } + + +def default_output_dir(endpoint: str) -> Path: + host = urlsplit(endpoint).hostname or "sourcegraph" + safe_host = re.sub(r"[^A-Za-z0-9_.-]+", "-", host) + timestamp = datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d-%H-%M-%S") + return Path("src-auth-perms-sync-runs") / safe_host / "memory-model-sweep" / timestamp + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/dev/sourcegraph-explicit-permissions-tracing.md b/dev/sourcegraph-explicit-permissions-tracing.md index fe8de9a..c61a399 100644 --- a/dev/sourcegraph-explicit-permissions-tracing.md +++ b/dev/sourcegraph-explicit-permissions-tracing.md @@ -95,8 +95,9 @@ To trace the full integration matrix, run the end-to-end script with its own tails each child run log and fetches all traced GraphQL Jaeger traces in the background while that child command is still running. The runner uses `src-py-lib` Config parsing, logging, Sourcegraph endpoint normalization, -`SourcegraphClient.fetch_jaeger_trace_summary()`, and a shared HTTP pool, so -trace summary and retry behavior match the CLI's Sourcegraph client: +`SourcegraphClient.fetch_jaeger_trace()`, `summarize_jaeger_trace()`, and a +shared HTTP pool, so trace fetch and summary behavior match the CLI's +Sourcegraph client: ```bash uv run python dev/test-end-to-end.py \ @@ -116,11 +117,35 @@ The runner writes trace summaries incrementally as JSON Lines. By default, it uses a sibling of `--results-json` or `--results-csv`, named `*-jaeger-traces.jsonl`. Override this with `--jaeger-trace-jsonl PATH`. +The runner also writes complete raw Jaeger trace payloads for in-depth +follow-up. By default, it uses a sibling directory named `*-jaeger-traces`. +Override this with `--jaeger-trace-dir PATH`. Each file is stored by variant, +iteration, case, and trace ID: + +```text +/ + candidate/ + iteration-0001/ + set-full-no-backup-apply/ + .json +``` + +Each raw trace file includes: + +- `trace_request`: CLI-side correlation metadata from the HTTP request and the + surrounding `graphql_query` event, including query name, page number, page + size, cursor presence, query byte count, variable names, response fields, + status, and timing. If `src-py-lib` later logs sanitized GraphQL variable + values, the same field will include them as `variables`, `input_variables`, + or `variable_values`. +- `jaeger_summary`: compact hot-operation and GraphQL-operation summary. +- `jaeger_trace`: the complete Jaeger trace JSON returned by Sourcegraph. + The shared `src-py-lib` `stream_jaeger_trace_summaries()` helper now fetches in parallel for in-process Sourcegraph clients. The end-to-end script still uses a bounded global worker pool because the traced requests happen in child processes and are discovered by tailing their JSON logs. Tune this with -`--jaeger-trace-parallelism N` (default 16). The runner drains outstanding +`--jaeger-trace-parallelism N` (default 8). The runner drains outstanding background collectors once at the end, before it writes JSON/CSV results, so Jaeger collection does not add a blocking phase between child cases. @@ -136,6 +161,46 @@ For each tested batch size and parallelism, record: `sql.conn.query`, and `database.PermsStore.LoadUserPermissions` - retries/timeouts from the CLI log +## Monitor Sourcegraph pod load during e2e runs + +Prefer running the end-to-end script as the single orchestrator. It can start +the Sourcegraph pod/Postgres monitor, collect Jaeger traces in parallel with +each child CLI command, and write all artifact paths into the result JSON: + +```bash +uv run python dev/test-end-to-end.py \ + --trace \ + --monitor-sourcegraph-load \ + --sample-interval 0 \ + --external-sample-interval 0 \ + --results-json /tmp/src-auth-perms-sync-end-to-end-trace.json \ + --results-csv /tmp/src-auth-perms-sync-end-to-end-trace.csv +``` + +By default, monitor output is written beside `--results-json` or +`--results-csv` as `*-sourcegraph-load`, and the monitor's own stdout/stderr is +written to `*-sourcegraph-load.log`. Override the location with +`--monitor-output-dir PATH`. Tune Kubernetes targets and sample intervals with +the `--monitor-*` flags if the test namespace or pod names differ. + +The lower-level helper remains available for focused profiling outside a full +e2e run: + +```bash +dev/monitor-sourcegraph-load.sh \ + --namespace m \ + --output-dir /tmp/src-auth-perms-sync-sourcegraph-load-$(date -u +%Y%m%d-%H%M%S) +``` + +Stop the helper with Ctrl-C after the e2e run finishes, or add +`--duration-seconds N`. The script samples Kubernetes CPU/memory, frontend and +Postgres processes, cgroup CPU/memory pressure, Postgres active queries/waits/locks, +`pg_stat_statements` when enabled, and frontend logs. Outputs are timestamped +files in the selected directory. On startup, it runs `CREATE EXTENSION IF NOT +EXISTS pg_stat_statements` and `pg_stat_statements_reset()` through +`kubectl exec` against `pod/pgsql-0`, so the statement summary starts clean for +the monitored run. + In a traced sgdev end-to-end run after the matrix was trimmed to avoid overlapping code paths, all 36 cases passed. Child command time summed to about 1,126 seconds. The JSONL trace summary file contained 3,256 GraphQL trace diff --git a/dev/test-end-to-end.py b/dev/test-end-to-end.py index 90e1c73..97f5bb3 100755 --- a/dev/test-end-to-end.py +++ b/dev/test-end-to-end.py @@ -13,19 +13,22 @@ from __future__ import annotations +import contextlib import csv import datetime +import heapq import json import os import re import shlex +import signal import statistics import subprocess import sys import threading import time from collections.abc import Iterable, Mapping, Sequence -from concurrent.futures import Future, ThreadPoolExecutor +from concurrent.futures import Future from concurrent.futures import wait as wait_for_futures from dataclasses import dataclass from pathlib import Path @@ -35,37 +38,44 @@ import src_py_lib as src from src_py_lib.clients.sourcegraph import ( DEFAULT_SOURCEGRAPH_ENDPOINT, - JAEGER_TRACE_RETRY_DELAYS_SECONDS, sourcegraph_trace_from_headers, + summarize_jaeger_trace, ) LOG_PATH_PATTERN = re.compile(r"Writing log events to (.+?/log\.json)\.") +SAFE_PATH_PART_PATTERN = re.compile(r"[^A-Za-z0-9_.-]+") DEFAULT_FUTURE_DATE = "2099-01-01" REMOVED_SRC_AUTH_PERMS_SYNC_ENVIRONMENT_PREFIX = "SRC_AUTH_PERMS_SYNC_" DEFAULT_SAMPLE_INTERVAL_SECONDS = 1.0 DEFAULT_REPEAT_COUNT = 1 DEFAULT_JAEGER_TRACE_LIMIT: int | None = None -DEFAULT_JAEGER_TRACE_PARALLELISM = 16 +DEFAULT_JAEGER_TRACE_PARALLELISM = 8 +DEFAULT_JAEGER_INITIAL_DELAY_SECONDS = 35.0 +DEFAULT_JAEGER_RETRY_DELAYS_SECONDS = ( + 2.0, + 5.0, + 10.0, + 20.0, + 30.0, + 60.0, + 60.0, + 60.0, + 60.0, + 60.0, + 60.0, +) DEFAULT_PARALLELISM = 4 DEFAULT_FULL_RESTORE_PARALLELISM = 1 +DEFAULT_INCLUDE_REDUNDANT_SCALE_CASES = False DEFAULT_MEMORY_SUMMARY_LIMIT = 20 DEFAULT_SRC_AUTH_PERMS_SYNC_COMMAND = "uv run src-auth-perms-sync" -WORKLOAD_FIELDS = ( - "user_count", - "total_users", - "total_users_scanned", - "repo_count", - "repos_with_explicit_grants", - "total_grants", - "mapping_count", - "plan_size", - "payload_count", - "target_organizations", - "desired_memberships", - "mutations_succeeded", - "mutations_failed", - "mutations_canceled", -) +DEFAULT_SOURCEGRAPH_MONITOR_NAMESPACE = "m" +DEFAULT_SOURCEGRAPH_MONITOR_INTERVAL_SECONDS = 5 +DEFAULT_SOURCEGRAPH_MONITOR_POSTGRES_INTERVAL_SECONDS = 10 +DEFAULT_SOURCEGRAPH_MONITOR_STATEMENTS_INTERVAL_SECONDS = 30 +DEFAULT_SOURCEGRAPH_MONITOR_FRONTEND_TARGET = "deployment/sourcegraph-frontend" +DEFAULT_SOURCEGRAPH_MONITOR_POSTGRES_TARGET = "pod/pgsql-0" +DEFAULT_SOURCEGRAPH_MONITOR_PSQL_COMMAND = "psql -X -U sg -d sg" def format_jaeger_retry_delays(delays: Sequence[float]) -> str: @@ -160,6 +170,16 @@ class EndToEndConfig(src.SourcegraphClientConfig, src.LoggingConfig): f"(default: {DEFAULT_FULL_RESTORE_PARALLELISM})" ), ) + include_redundant_scale_cases: bool = src.config_field( + default=DEFAULT_INCLUDE_REDUNDANT_SCALE_CASES, + env_var="SRC_AUTH_PERMS_SYNC_E2E_INCLUDE_REDUNDANT_SCALE_CASES", + cli_flag="--include-redundant-scale-cases", + cli_action="store_true", + help=( + "Also run older overlapping full-scale cases. Default keeps one heavy full " + "snapshot path and uses smaller cases for overlapping coverage." + ), + ) allow_non_test_endpoint: bool = src.config_field( default=False, env_var="SRC_AUTH_PERMS_SYNC_E2E_ALLOW_NON_TEST_ENDPOINT", @@ -203,6 +223,17 @@ class EndToEndConfig(src.SourcegraphClientConfig, src.LoggingConfig): f"(default: {DEFAULT_JAEGER_TRACE_PARALLELISM})" ), ) + jaeger_initial_delay_seconds: float = src.config_field( + default=DEFAULT_JAEGER_INITIAL_DELAY_SECONDS, + env_var="SRC_AUTH_PERMS_SYNC_E2E_JAEGER_INITIAL_DELAY_SECONDS", + cli_flag="--jaeger-initial-delay-seconds", + metavar="SECONDS", + ge=0, + help=( + "Seconds to wait before first fetching each Jaeger trace, to allow OTel tail " + f"sampling to decide (default: {DEFAULT_JAEGER_INITIAL_DELAY_SECONDS:g})" + ), + ) jaeger_trace_jsonl: Path | None = src.config_field( default=None, env_var="SRC_AUTH_PERMS_SYNC_E2E_JAEGER_TRACE_JSONL", @@ -213,14 +244,26 @@ class EndToEndConfig(src.SourcegraphClientConfig, src.LoggingConfig): "of --results-json or --results-csv when --trace is set." ), ) + jaeger_trace_directory: Path | None = src.config_field( + default=None, + env_var="SRC_AUTH_PERMS_SYNC_E2E_JAEGER_TRACE_DIR", + cli_flag="--jaeger-trace-dir", + metavar="PATH", + help=( + "Directory where complete raw Jaeger trace JSON files are written. Defaults " + "to a sibling directory of --results-json or --results-csv when --trace is set." + ), + ) jaeger_retry_delays: tuple[float, ...] = src.config_field( - default=JAEGER_TRACE_RETRY_DELAYS_SECONDS, + default=DEFAULT_JAEGER_RETRY_DELAYS_SECONDS, env_var="SRC_AUTH_PERMS_SYNC_E2E_JAEGER_RETRY_DELAYS", cli_flag="--jaeger-retry-delays", metavar="SECONDS[,SECONDS...]", help=( - "Comma-separated retry delays for Jaeger trace lookup lag " - f"(default: {format_jaeger_retry_delays(JAEGER_TRACE_RETRY_DELAYS_SECONDS)})" + "Comma-separated delays between queued Jaeger trace fetch retries. " + "Each value schedules one retry after the initial fetch; add more values " + "to try for longer " + f"(default: {format_jaeger_retry_delays(DEFAULT_JAEGER_RETRY_DELAYS_SECONDS)})" ), ) sample_interval: float = src.config_field( @@ -271,6 +314,103 @@ class EndToEndConfig(src.SourcegraphClientConfig, src.LoggingConfig): "beside it as *-phases.csv" ), ) + monitor_sourcegraph_load: bool = src.config_field( + default=False, + env_var="SRC_AUTH_PERMS_SYNC_E2E_MONITOR_SOURCEGRAPH_LOAD", + cli_flag="--monitor-sourcegraph-load", + cli_action="store_true", + help=( + "Start the Sourcegraph pod/Postgres load monitor for this e2e run and write " + "its output beside the result artifacts." + ), + ) + sourcegraph_monitor_namespace: str = src.config_field( + default=DEFAULT_SOURCEGRAPH_MONITOR_NAMESPACE, + env_var="SRC_AUTH_PERMS_SYNC_E2E_MONITOR_NAMESPACE", + cli_flag="--monitor-namespace", + metavar="NAME", + help=( + "Kubernetes namespace for Sourcegraph load monitoring " + f"(default: {DEFAULT_SOURCEGRAPH_MONITOR_NAMESPACE})" + ), + ) + sourcegraph_monitor_output_dir: Path | None = src.config_field( + default=None, + env_var="SRC_AUTH_PERMS_SYNC_E2E_MONITOR_OUTPUT_DIR", + cli_flag="--monitor-output-dir", + metavar="PATH", + help="Directory for Sourcegraph load monitor output; defaults beside result artifacts.", + ) + sourcegraph_monitor_interval_seconds: int = src.config_field( + default=DEFAULT_SOURCEGRAPH_MONITOR_INTERVAL_SECONDS, + env_var="SRC_AUTH_PERMS_SYNC_E2E_MONITOR_INTERVAL_SECONDS", + cli_flag="--monitor-interval-seconds", + metavar="SECONDS", + ge=1, + help=( + "Pod/process/cgroup monitor interval in seconds " + f"(default: {DEFAULT_SOURCEGRAPH_MONITOR_INTERVAL_SECONDS})" + ), + ) + sourcegraph_monitor_postgres_interval_seconds: int = src.config_field( + default=DEFAULT_SOURCEGRAPH_MONITOR_POSTGRES_INTERVAL_SECONDS, + env_var="SRC_AUTH_PERMS_SYNC_E2E_MONITOR_POSTGRES_INTERVAL_SECONDS", + cli_flag="--monitor-postgres-interval-seconds", + metavar="SECONDS", + ge=1, + help=( + "Postgres activity monitor interval in seconds " + f"(default: {DEFAULT_SOURCEGRAPH_MONITOR_POSTGRES_INTERVAL_SECONDS})" + ), + ) + sourcegraph_monitor_statements_interval_seconds: int = src.config_field( + default=DEFAULT_SOURCEGRAPH_MONITOR_STATEMENTS_INTERVAL_SECONDS, + env_var="SRC_AUTH_PERMS_SYNC_E2E_MONITOR_STATEMENTS_INTERVAL_SECONDS", + cli_flag="--monitor-statements-interval-seconds", + metavar="SECONDS", + ge=1, + help=( + "pg_stat_statements monitor interval in seconds " + f"(default: {DEFAULT_SOURCEGRAPH_MONITOR_STATEMENTS_INTERVAL_SECONDS})" + ), + ) + sourcegraph_monitor_frontend_target: str = src.config_field( + default=DEFAULT_SOURCEGRAPH_MONITOR_FRONTEND_TARGET, + env_var="SRC_AUTH_PERMS_SYNC_E2E_MONITOR_FRONTEND_TARGET", + cli_flag="--monitor-frontend-target", + metavar="TARGET", + help=( + "kubectl target for Sourcegraph frontend " + f"(default: {DEFAULT_SOURCEGRAPH_MONITOR_FRONTEND_TARGET})" + ), + ) + sourcegraph_monitor_postgres_target: str = src.config_field( + default=DEFAULT_SOURCEGRAPH_MONITOR_POSTGRES_TARGET, + env_var="SRC_AUTH_PERMS_SYNC_E2E_MONITOR_POSTGRES_TARGET", + cli_flag="--monitor-postgres-target", + metavar="TARGET", + help=( + "kubectl target for Sourcegraph Postgres " + f"(default: {DEFAULT_SOURCEGRAPH_MONITOR_POSTGRES_TARGET})" + ), + ) + sourcegraph_monitor_psql_command: str = src.config_field( + default=DEFAULT_SOURCEGRAPH_MONITOR_PSQL_COMMAND, + env_var="SRC_AUTH_PERMS_SYNC_E2E_MONITOR_PSQL_COMMAND", + cli_flag="--monitor-psql-command", + metavar="COMMAND", + help=( + "psql command to run inside the Postgres pod " + f"(default: {DEFAULT_SOURCEGRAPH_MONITOR_PSQL_COMMAND})" + ), + ) + sourcegraph_monitor_no_logs: bool = src.config_field( + default=False, + env_var="SRC_AUTH_PERMS_SYNC_E2E_MONITOR_NO_LOGS", + cli_flag="--monitor-no-logs", + cli_action="store_true", + help="Do not stream frontend logs while Sourcegraph load monitoring is enabled.", + ) fail_on_memory_regression_percent: float | None = src.config_field( default=None, env_var="SRC_AUTH_PERMS_SYNC_E2E_FAIL_ON_MEMORY_REGRESSION_PERCENT", @@ -431,22 +571,136 @@ def sample_once(self) -> None: self.peak_rss_mb = max_optional_float(self.peak_rss_mb, rss_mb) +class SourcegraphLoadMonitor: + """Run the Sourcegraph pod/Postgres monitor for the duration of the e2e suite.""" + + def __init__(self, config: EndToEndConfig, output_dir: Path) -> None: + self.config = config + self.output_dir = output_dir + self.log_path = output_dir.with_name(f"{output_dir.name}.log") + self._log_file: TextIO | None = None + self._process: subprocess.Popen[str] | None = None + + def start(self) -> None: + script_path = sourcegraph_monitor_script_path() + if not script_path.exists(): + raise RuntimeError(f"Sourcegraph load monitor script not found: {script_path}") + self.output_dir.parent.mkdir(parents=True, exist_ok=True) + self.log_path.parent.mkdir(parents=True, exist_ok=True) + command = [ + str(script_path), + "--namespace", + self.config.sourcegraph_monitor_namespace, + "--output-dir", + str(self.output_dir), + "--interval-seconds", + str(self.config.sourcegraph_monitor_interval_seconds), + "--postgres-interval-seconds", + str(self.config.sourcegraph_monitor_postgres_interval_seconds), + "--statements-interval-seconds", + str(self.config.sourcegraph_monitor_statements_interval_seconds), + "--frontend-target", + self.config.sourcegraph_monitor_frontend_target, + "--postgres-target", + self.config.sourcegraph_monitor_postgres_target, + "--psql-command", + self.config.sourcegraph_monitor_psql_command, + ] + if self.config.sourcegraph_monitor_no_logs: + command.append("--no-logs") + print(f"Starting Sourcegraph load monitor: {self.output_dir}") + self._log_file = self.log_path.open("w", encoding="utf-8") + self._process = subprocess.Popen( # noqa: S603 - command is trusted test config. + command, + cwd=Path.cwd(), + stdout=self._log_file, + stderr=subprocess.STDOUT, + text=True, + start_new_session=True, + ) + self._wait_until_started() + + def stop(self) -> None: + process = self._process + if process is None: + self._close_log_file() + return + if process.poll() is None: + with contextlib.suppress(ProcessLookupError): + os.killpg(process.pid, signal.SIGTERM) + try: + process.wait(timeout=15) + except subprocess.TimeoutExpired: + with contextlib.suppress(ProcessLookupError): + os.killpg(process.pid, signal.SIGKILL) + process.wait(timeout=15) + return_code = process.returncode + self._close_log_file() + if return_code not in {0, -15, 143}: + print( + f"Sourcegraph load monitor exited with status {return_code}; see {self.log_path}", + file=sys.stderr, + ) + else: + print(f"Stopped Sourcegraph load monitor. Output: {self.output_dir}") + + def _wait_until_started(self) -> None: + process = self._process + if process is None: + return + deadline = time.monotonic() + 60 + while time.monotonic() < deadline: + if process.poll() is not None: + raise RuntimeError( + f"Sourcegraph load monitor exited before startup completed; see {self.log_path}" + ) + if self.log_path.exists() and "Started kubectl-top" in self.log_path.read_text( + encoding="utf-8", errors="ignore" + ): + return + time.sleep(0.2) + raise RuntimeError( + f"Timed out waiting for Sourcegraph load monitor startup; see {self.log_path}" + ) + + def _close_log_file(self) -> None: + if self._log_file is not None: + self._log_file.close() + self._log_file = None + + +@dataclass +class JaegerTraceFetchTask: + """One trace fetch request that can be retried across the whole e2e run.""" + + trace_request: dict[str, Any] + future: Future[dict[str, Any]] + fetch_attempts: int = 0 + first_fetch_at: str | None = None + last_fetch_at: str | None = None + + class JaegerTraceFetchPool: - """Fetch Sourcegraph Jaeger trace summaries through one bounded HTTP pool.""" + """Fetch Sourcegraph Jaeger traces through one bounded retry queue.""" def __init__( self, config: EndToEndConfig, *, parallelism: int, + initial_delay_seconds: float, retry_delays_seconds: Sequence[float], jsonl_path: Path | None, + trace_directory: Path | None, ) -> None: + self.initial_delay_seconds = initial_delay_seconds self.retry_delays_seconds = tuple(retry_delays_seconds) - self._executor = ThreadPoolExecutor( - max_workers=parallelism, - thread_name_prefix="JaegerTraceFetch", - ) + self.max_fetch_attempts = len(self.retry_delays_seconds) + 1 + self._trace_directory = trace_directory + self._tasks: list[tuple[float, int, JaegerTraceFetchTask]] = [] + self._condition = threading.Condition() + self._sequence = 0 + self._closed = False self._jsonl_file: TextIO | None = None self._lock = threading.Lock() http = src.HTTPClient( @@ -459,36 +713,166 @@ def __init__( jsonl_path.parent.mkdir(parents=True, exist_ok=True) self._jsonl_file = jsonl_path.open("w", encoding="utf-8") print(f"Writing Jaeger trace summaries incrementally to {jsonl_path}") + if self._trace_directory is not None: + self._trace_directory.mkdir(parents=True, exist_ok=True) + print(f"Writing complete Jaeger traces to {self._trace_directory}") + self._workers = [ + threading.Thread( + target=self._worker, + name=f"JaegerTraceFetch-{worker_number}", + daemon=True, + ) + for worker_number in range(1, parallelism + 1) + ] + for worker in self._workers: + worker.start() def submit( self, trace_request: dict[str, Any], collector: JaegerTraceCollector, ) -> Future[dict[str, Any]]: - future = src.submit_with_log_context(self._executor, self._fetch_summary, trace_request) + future: Future[dict[str, Any]] = Future() future.add_done_callback(lambda completed: self._record_summary(collector, completed)) + task = JaegerTraceFetchTask( + trace_request=trace_request, + future=future, + ) + self._schedule(task, self.initial_delay_seconds) return future def close(self) -> None: - self._executor.shutdown(wait=True) + with self._condition: + self._closed = True + self._condition.notify_all() + for worker in self._workers: + worker.join() self._client.http.close() if self._jsonl_file is not None: self._jsonl_file.close() - def _fetch_summary(self, trace_request: dict[str, Any]) -> dict[str, Any]: + def _schedule(self, task: JaegerTraceFetchTask, delay_seconds: float) -> None: + with self._condition: + self._sequence += 1 + heapq.heappush( + self._tasks, + (time.monotonic() + delay_seconds, self._sequence, task), + ) + self._condition.notify() + + def _worker(self) -> None: + while True: + task = self._next_ready_task() + if task is None: + return + self._process(task) + + def _next_ready_task(self) -> JaegerTraceFetchTask | None: + with self._condition: + while True: + if self._closed and not self._tasks: + return None + if not self._tasks: + self._condition.wait() + continue + ready_at, _sequence, task = self._tasks[0] + delay_seconds = ready_at - time.monotonic() + if delay_seconds > 0: + self._condition.wait(delay_seconds) + continue + heapq.heappop(self._tasks) + return task + + def _process(self, task: JaegerTraceFetchTask) -> None: + if task.future.done(): + return + summary = self._fetch_summary(task) + if summary.get("jaeger_found") is True or not self._should_retry(task, summary): + task.future.set_result(summary) + return + self._schedule(task, self._retry_delay_seconds(task.fetch_attempts)) + + def _fetch_summary(self, task: JaegerTraceFetchTask) -> dict[str, Any]: + task.fetch_attempts += 1 + now = datetime.datetime.now(datetime.UTC).isoformat(timespec="seconds") + if task.first_fetch_at is None: + task.first_fetch_at = now + task.last_fetch_at = now try: - trace = sourcegraph_trace_from_request(trace_request) - summary = self._client.fetch_jaeger_trace_summary( - trace, - retry_delays_seconds=self.retry_delays_seconds, - ).to_json() - return {**trace_request, **summary} + trace = sourcegraph_trace_from_request(task.trace_request) + jaeger_trace = self._client.fetch_jaeger_trace( + trace.trace_id, + retry_delays_seconds=(0.0,), + ) + summary = summarize_jaeger_trace(trace, jaeger_trace).to_json() + try: + trace_path = self._write_complete_trace(task, jaeger_trace, summary) + if trace_path is not None: + summary["jaeger_trace_path"] = str(trace_path) + except OSError as write_error: + summary["jaeger_trace_write_error"] = f"{type(write_error).__name__}: {write_error}" + return self._with_fetch_fields(task, summary) except Exception as exception: # noqa: BLE001 - keep long-running evidence collection alive. - return { - **trace_request, - "jaeger_found": False, - "error": f"{type(exception).__name__}: {exception}", - } + return self._with_fetch_fields( + task, + { + **task.trace_request, + "jaeger_found": False, + "error": f"{type(exception).__name__}: {exception}", + }, + ) + + def _with_fetch_fields( + self, task: JaegerTraceFetchTask, summary: dict[str, Any] + ) -> dict[str, Any]: + return { + **task.trace_request, + **summary, + "fetch_attempts": task.fetch_attempts, + "first_fetch_at": task.first_fetch_at, + "last_fetch_at": task.last_fetch_at, + "max_fetch_attempts": self.max_fetch_attempts, + } + + def _write_complete_trace( + self, + task: JaegerTraceFetchTask, + jaeger_trace: dict[str, Any], + summary: dict[str, Any], + ) -> Path | None: + if self._trace_directory is None: + return None + path = complete_jaeger_trace_path(self._trace_directory, task.trace_request) + payload = { + "collected_at": task.last_fetch_at, + "fetch_attempts": task.fetch_attempts, + "max_fetch_attempts": self.max_fetch_attempts, + "trace_request": task.trace_request, + "jaeger_summary": summary, + "jaeger_trace": jaeger_trace, + } + path.parent.mkdir(parents=True, exist_ok=True) + temporary_path = path.with_name( + f".{path.name}.tmp-{threading.get_ident()}-{time.monotonic_ns()}" + ) + temporary_path.write_text( + json.dumps(payload, indent=2, sort_keys=True) + "\n", + encoding="utf-8", + ) + temporary_path.replace(path) + return path + + def _should_retry(self, task: JaegerTraceFetchTask, summary: dict[str, Any]) -> bool: + if self._closed or task.fetch_attempts >= self.max_fetch_attempts: + return False + error = str(summary.get("error") or "") + return error.startswith(("HTTP 404", "HTTP 502", "HTTP 503", "HTTP 504")) + + def _retry_delay_seconds(self, fetch_attempts: int) -> float: + if not self.retry_delays_seconds: + return 0.0 + delay_index = min(fetch_attempts - 1, len(self.retry_delays_seconds) - 1) + return self.retry_delays_seconds[delay_index] def _record_summary( self, @@ -527,6 +911,8 @@ def __init__( self.iteration = iteration self.case_name = case_name self.summaries: list[dict[str, Any]] = [] + self._graphql_queries_by_span: dict[tuple[str, str], dict[str, Any]] = {} + self._trace_requests_by_graphql_span: dict[tuple[str, str], dict[str, Any]] = {} self._requests_by_trace_id: dict[str, dict[str, Any]] = {} self._queued_trace_ids: set[str] = set() self._futures: list[Future[dict[str, Any]]] = [] @@ -599,15 +985,22 @@ def _record_line(self, line: str) -> None: return if not isinstance(record, dict): return + self._record_graphql_query_metadata(cast(dict[str, Any], record)) trace_request = graphql_trace_request_from_record(cast(dict[str, Any], record)) if trace_request is None: return trace_request.update( {"variant": self.variant, "iteration": self.iteration, "case": self.case_name} ) + graphql_span_key = self._graphql_span_key_for_http_record(cast(dict[str, Any], record)) trace_id = trace_request["trace_id"] submit_request: dict[str, Any] | None = None with self._lock: + if graphql_span_key is not None: + graphql_query = self._graphql_queries_by_span.get(graphql_span_key) + if graphql_query is not None: + trace_request["graphql_query"] = dict(graphql_query) + self._trace_requests_by_graphql_span[graphql_span_key] = trace_request existing_request = self._requests_by_trace_id.get(trace_id) if existing_request is None or trace_summary_duration_ms( trace_request @@ -621,6 +1014,29 @@ def _record_line(self, line: str) -> None: with self._lock: self._futures.append(future) + def _record_graphql_query_metadata(self, record: dict[str, Any]) -> None: + metadata = graphql_query_metadata_from_record(record) + if metadata is None: + return + span_key = graphql_query_span_key(record) + if span_key is None: + return + with self._lock: + existing_metadata = self._graphql_queries_by_span.get(span_key, {}) + merged_metadata = existing_metadata | metadata + self._graphql_queries_by_span[span_key] = merged_metadata + trace_request = self._trace_requests_by_graphql_span.get(span_key) + if trace_request is not None: + trace_request["graphql_query"] = dict(merged_metadata) + + @staticmethod + def _graphql_span_key_for_http_record(record: dict[str, Any]) -> tuple[str, str] | None: + trace_id = optional_string(record.get("trace")) + parent_span_id = optional_string(record.get("parent_span")) + if trace_id is None or parent_span_id is None: + return None + return trace_id, parent_span_id + def _submit_limited_requests(self) -> None: if self.limit is None: return @@ -879,14 +1295,22 @@ def run_end_to_end(config: EndToEndConfig) -> None: all_failures: list[str] = [] all_jaeger_collectors: list[JaegerTraceCollector] = [] jaeger_trace_fetch_pool = create_jaeger_trace_fetch_pool(config) + sourcegraph_load_monitor = create_sourcegraph_load_monitor(config) latest_baseline_repositories: set[str] = set() try: + if sourcegraph_load_monitor is not None: + sourcegraph_load_monitor.start() with src.event( "end_to_end_matrix", repeat=config.repeat, variant_count=len(variants), trace=config.trace, + sourcegraph_load_monitor=sourcegraph_load_monitor is not None, ) as matrix_summary: + if sourcegraph_load_monitor is not None: + matrix_summary["sourcegraph_load_monitor_dir"] = str( + sourcegraph_load_monitor.output_dir + ) for iteration in range(1, config.repeat + 1): for variant in variants: with src.stage("matrix_variant", variant=variant.name, iteration=iteration): @@ -915,6 +1339,8 @@ def run_end_to_end(config: EndToEndConfig) -> None: wait_for_jaeger_trace_collectors(all_jaeger_collectors) if jaeger_trace_fetch_pool is not None: jaeger_trace_fetch_pool.close() + if sourcegraph_load_monitor is not None: + sourcegraph_load_monitor.stop() if all_failures: print("\nFailures:", file=sys.stderr) for failure in all_failures: @@ -928,7 +1354,7 @@ def run_end_to_end(config: EndToEndConfig) -> None: print_phase_memory_summary(all_results, config.memory_summary_limit) comparisons = compare_variants(all_results) print_comparison_summary(comparisons) - write_results_files(all_results, comparisons, config) + write_results_files(all_results, comparisons, config, sourcegraph_load_monitor) raise_for_memory_regressions(comparisons, config) @@ -955,8 +1381,10 @@ def create_jaeger_trace_fetch_pool( return JaegerTraceFetchPool( config, parallelism=config.jaeger_trace_parallelism, + initial_delay_seconds=config.jaeger_initial_delay_seconds, retry_delays_seconds=config.jaeger_retry_delays, jsonl_path=jaeger_trace_jsonl_path(config), + trace_directory=jaeger_trace_directory(config), ) @@ -971,6 +1399,56 @@ def jaeger_trace_jsonl_path(config: EndToEndConfig) -> Path | None: return Path("/tmp") / f"src-auth-perms-sync-end-to-end-jaeger-traces-{stamp}.jsonl" +def jaeger_trace_directory(config: EndToEndConfig) -> Path: + """Return the directory where complete raw Jaeger traces should be stored.""" + if config.jaeger_trace_directory is not None: + return config.jaeger_trace_directory + anchor = config.results_json or config.results_csv + if anchor is not None: + return anchor.with_name(f"{anchor.stem}-jaeger-traces") + stamp = datetime.datetime.now(datetime.UTC).strftime("%Y%m%d-%H%M%S") + return Path("/tmp") / f"src-auth-perms-sync-end-to-end-jaeger-traces-{stamp}" + + +def create_sourcegraph_load_monitor(config: EndToEndConfig) -> SourcegraphLoadMonitor | None: + """Return the Sourcegraph load monitor for this run, if enabled.""" + if not config.monitor_sourcegraph_load: + return None + return SourcegraphLoadMonitor(config, sourcegraph_monitor_output_dir(config)) + + +def sourcegraph_monitor_output_dir(config: EndToEndConfig) -> Path: + """Return where Sourcegraph pod/Postgres monitor artifacts should be stored.""" + if config.sourcegraph_monitor_output_dir is not None: + return config.sourcegraph_monitor_output_dir + anchor = config.results_json or config.results_csv + if anchor is not None: + return anchor.with_name(f"{anchor.stem}-sourcegraph-load") + stamp = datetime.datetime.now(datetime.UTC).strftime("%Y%m%d-%H%M%S") + return Path("/tmp") / f"src-auth-perms-sync-end-to-end-sourcegraph-load-{stamp}" + + +def sourcegraph_monitor_script_path() -> Path: + """Return the lower-level monitor script used by the e2e orchestrator.""" + return Path(__file__).resolve().with_name("monitor-sourcegraph-load.sh") + + +def complete_jaeger_trace_path(trace_directory: Path, trace_request: dict[str, Any]) -> Path: + """Return the stable per-trace path for a complete Jaeger trace payload.""" + variant = safe_path_part(trace_request.get("variant"), default="variant") + iteration = int_field(trace_request, "iteration") or 0 + case_name = safe_path_part(trace_request.get("case"), default="case") + trace_id = safe_path_part(trace_request.get("trace_id"), default="trace") + return trace_directory / variant / f"iteration-{iteration:04d}" / case_name / f"{trace_id}.json" + + +def safe_path_part(value: object, *, default: str) -> str: + """Return a filesystem-safe path segment for generated trace artifacts.""" + text = str(value) if value is not None else "" + safe_text = SAFE_PATH_PART_PATTERN.sub("-", text).strip("-.") + return safe_text[:120] or default + + def command_environment(config: EndToEndConfig) -> dict[str, str]: """Return a deterministic child environment for CLI config parsing.""" environment = dict(os.environ) @@ -1187,7 +1665,7 @@ def read_only_cases(config: EndToEndConfig) -> list[CommandCase]: ), CommandCase( name="get-sync-saml-orgs-dry-run", - arguments=("--get", "--sync-saml-orgs"), + arguments=("--get", "--user", config.user, "--sync-saml-orgs"), expected_log_command="get_sync_saml_orgs", must_contain=("Wrote before-snapshot", "Dry run complete"), ), @@ -1325,29 +1803,30 @@ def run_full_apply_cases(config: EndToEndConfig, runner: CommandPermutationRunne ) baseline_snapshot = snapshot_path(dry_run_result) - try: - runner.run( - CommandCase( - name="set-full-apply", - arguments=( - "--set", - "--apply", - "--parallelism", - str(config.parallelism), - ), - expected_log_command="set_full", - must_contain=("VALIDATION OK",), + if config.include_redundant_scale_cases: + try: + runner.run( + CommandCase( + name="set-full-apply", + arguments=( + "--set", + "--apply", + "--parallelism", + str(config.parallelism), + ), + expected_log_command="set_full", + must_contain=("VALIDATION OK",), + ) ) - ) - finally: - runner.run( - restore_full_apply_case( - "restore-full-apply-cleanup", - baseline_snapshot, - config, - no_backup=False, + finally: + runner.run( + restore_full_apply_case( + "restore-full-apply-cleanup", + baseline_snapshot, + config, + no_backup=False, + ) ) - ) try: runner.run( @@ -1374,14 +1853,14 @@ def run_full_apply_cases(config: EndToEndConfig, runner: CommandPermutationRunne ) ) - # Covers the combined set+SAML dispatch and SAML dry-run path without - # repeating the full set apply and full restore cleanup paths, which are - # already covered above. + # Covers combined set+SAML dispatch and SAML dry-run with a user-scoped + # set path, so the default suite keeps only one expensive full-snapshot + # case. Pass --include-redundant-scale-cases to restore older overlap. runner.run( CommandCase( - name="set-full-sync-saml-orgs-dry-run", - arguments=("--set", "--sync-saml-orgs"), - expected_log_command="set_full_sync_saml_orgs", + name="set-user-sync-saml-orgs-dry-run", + arguments=("--set", "maps.yaml", "--user", config.user, "--sync-saml-orgs"), + expected_log_command="set_user_sync_saml_orgs", must_contain=("Dry run complete",), ) ) @@ -1602,20 +2081,178 @@ def parse_log_timestamp(value: object) -> datetime.datetime | None: def workload_from_records(records: list[dict[str, Any]]) -> dict[str, int | float | str]: - """Collect stable workload-size fields so memory can be normalized.""" + """Collect named workload dimensions from structured log records. + + Earlier e2e summaries used raw field names from unrelated events, which made + values like `total_users` and `repo_count` ambiguous. Keep this summary + event-aware so each key says what it counts. + """ workload: dict[str, int | float | str] = {} for record in records: - for field_name in WORKLOAD_FIELDS: - value = record.get(field_name) - if isinstance(value, int | float): - old_value = workload.get(field_name) - if not isinstance(old_value, int | float) or value > old_value: - workload[field_name] = value - elif isinstance(value, str) and field_name not in workload: - workload[field_name] = value + event_name = optional_string(record.get("event")) + phase = optional_string(record.get("phase")) + if event_name == "capture_explicit_grants": + record_workload_max(workload, "sourcegraph_user_count", record.get("total_users")) + if phase == "end": + record_workload_max(workload, "captured_user_count", record.get("user_count")) + elif event_name in {"build_snapshot", "build_user_scoped_snapshot"} and phase == "end": + record_workload_max(workload, "snapshot_user_count_max", record.get("user_count")) + record_workload_max( + workload, + "snapshot_repos_with_explicit_grants_max", + record.get("repos_with_explicit_grants"), + ) + record_workload_max(workload, "snapshot_total_grants_max", record.get("total_grants")) + record_workload_max(workload, "captured_user_count", record.get("user_count")) + elif event_name == "user_explicit_repos_batch_fetch" and phase == "end": + record_workload_max(workload, "batch_user_count_max", record.get("user_count")) + record_workload_max( + workload, + "batch_fetched_grant_count_max", + record.get("fetched_grant_count") + if "fetched_grant_count" in record + else record.get("repo_count"), + ) + elif event_name == "load_repos_by_external_service" and phase == "end": + record_workload_max(workload, "loaded_repo_count", record.get("repo_count")) + record_workload_max( + workload, + "expected_repo_count", + record.get("expected_repo_count"), + ) + elif event_name == "apply_username_overwrites": + record_workload_max(workload, "apply_payload_count", record.get("payload_count")) + record_workload_max( + workload, + "apply_payload_grant_count", + record.get("payload_grant_count") + if "payload_grant_count" in record + else record.get("total_users"), + ) + record_workload_max(workload, "parallelism", record.get("parallelism")) + if phase == "end": + record_workload_max( + workload, + "apply_mutations_succeeded", + record.get("succeeded"), + ) + record_workload_max(workload, "apply_mutations_failed", record.get("failed")) + record_workload_max(workload, "apply_mutations_canceled", record.get("canceled")) + elif ( + event_name + in { + "cmd_get", + "cmd_restore", + "cmd_restore_user_scoped", + "cmd_set", + "cmd_set_additive_user", + "cmd_set_additive_users_without_explicit_perms", + } + and phase == "end" + ): + record_command_workload(workload, record) + elif event_name in {"sync_saml_orgs", "cmd_sync_saml_orgs"} and phase == "end": + record_workload_max( + workload, + "target_organizations", + record.get("target_organizations"), + ) + record_workload_max(workload, "desired_memberships", record.get("desired_memberships")) + + record_workload_model_dimensions(workload) return workload +def record_command_workload(workload: dict[str, int | float | str], record: dict[str, Any]) -> None: + """Copy command-level counts using names that preserve their meaning.""" + event_name = optional_string(record.get("event")) + repo_count = record.get("repo_count") + total_grants = record.get("total_grants") + if event_name == "cmd_set": + record_workload_max(workload, "planned_repo_count", repo_count) + record_workload_max(workload, "planned_total_grants", total_grants) + elif event_name == "cmd_get": + record_workload_max(workload, "selected_user_count", record.get("user_count")) + record_workload_max(workload, "selected_total_grants", total_grants) + elif event_name == "cmd_restore": + record_workload_max(workload, "restore_snapshot_repo_count", record.get("snapshot_repos")) + record_workload_max( + workload, + "restore_snapshot_total_grants", + record.get("snapshot_grants"), + ) + elif event_name == "cmd_set_additive_user": + record_workload_max(workload, "selected_user_count", record.get("user_count")) + record_workload_max(workload, "planned_repo_count", repo_count) + record_workload_max(workload, "planned_total_grants", total_grants) + + record_workload_max(workload, "mapping_count", record.get("mapping_count")) + record_workload_max(workload, "mutations_succeeded", record.get("mutations_succeeded")) + record_workload_max(workload, "mutations_failed", record.get("mutations_failed")) + record_workload_max(workload, "mutations_canceled", record.get("mutations_canceled")) + + +def record_workload_model_dimensions(workload: dict[str, int | float | str]) -> None: + """Add the canonical dimensions used by memory modeling.""" + user_count = max_workload_number( + workload, + ( + "selected_user_count", + "captured_user_count", + "snapshot_user_count_max", + "sourcegraph_user_count", + ), + ) + repo_count = max_workload_number( + workload, + ( + "planned_repo_count", + "restore_snapshot_repo_count", + "snapshot_repos_with_explicit_grants_max", + "loaded_repo_count", + ), + ) + grant_count = max_workload_number( + workload, + ( + "planned_total_grants", + "restore_snapshot_total_grants", + "selected_total_grants", + "snapshot_total_grants_max", + "apply_payload_grant_count", + ), + ) + if user_count is not None: + workload["memory_model_user_count"] = user_count + if repo_count is not None: + workload["memory_model_repo_count"] = repo_count + if grant_count is not None: + workload["memory_model_grant_count"] = grant_count + + +def max_workload_number( + workload: dict[str, int | float | str], field_names: Sequence[str] +) -> int | float | None: + """Return the largest numeric value found for the supplied workload fields.""" + values = [ + value + for field_name in field_names + if isinstance((value := workload.get(field_name)), int | float) + ] + return max(values) if values else None + + +def record_workload_max( + workload: dict[str, int | float | str], field_name: str, value: object +) -> None: + """Record the maximum numeric value for a named workload dimension.""" + if isinstance(value, bool) or not isinstance(value, int | float): + return + old_value = workload.get(field_name) + if not isinstance(old_value, int | float) or value > old_value: + workload[field_name] = value + + def artifact_sizes_for_run(log_path: Path) -> dict[str, int]: """Return sizes of JSON artifacts in the same run directory as the log.""" run_directory = log_path.parent @@ -1636,6 +2273,54 @@ def wait_for_jaeger_trace_collectors(collectors: list[JaegerTraceCollector]) -> collector.wait() +def graphql_query_metadata_from_record(record: dict[str, Any]) -> dict[str, Any] | None: + """Return correlation metadata from a structured `graphql_query` log record.""" + if record.get("event") != "graphql_query": + return None + metadata: dict[str, Any] = { + "span_id": record.get("span"), + "parent_span_id": record.get("parent_span"), + "trace_id": record.get("trace"), + } + phase = record.get("phase") + if phase == "start": + metadata["started_at"] = record.get("ts") + elif phase == "end": + metadata["ended_at"] = record.get("ts") + for field_name in ( + "cursor_present", + "duration_ms", + "error_type", + "graphql_client", + "page_number", + "page_size", + "query_bytes", + "query_name", + "response_fields", + "status", + "url", + "variable_names", + # Current src-py-lib logs variable names only. Keep these optional fields + # so raw trace artifacts automatically include values if the GraphQL log + # event grows an opt-in sanitized-variable field later. + "input_variables", + "variable_values", + "variables", + ): + if field_name in record: + metadata[field_name] = record[field_name] + return {key: value for key, value in metadata.items() if value is not None} + + +def graphql_query_span_key(record: dict[str, Any]) -> tuple[str, str] | None: + """Return the `(trace_id, span_id)` key for a GraphQL query log span.""" + trace_id = optional_string(record.get("trace")) + span_id = optional_string(record.get("span")) + if trace_id is None or span_id is None: + return None + return trace_id, span_id + + def graphql_trace_request_from_record(record: dict[str, Any]) -> dict[str, Any] | None: if record.get("event") != "http_request" or record.get("phase") != "end": return None @@ -1997,9 +2682,10 @@ def write_results_files( results: list[CommandResult], comparisons: list[CaseComparison], config: EndToEndConfig, + sourcegraph_load_monitor: SourcegraphLoadMonitor | None, ) -> None: if config.results_json is not None: - write_results_json(config.results_json, results, comparisons) + write_results_json(config.results_json, results, comparisons, sourcegraph_load_monitor) if config.results_csv is not None: write_results_csv(config.results_csv, results) phase_csv = phase_results_csv_path(config.results_csv) @@ -2010,12 +2696,20 @@ def write_results_json( path: Path, results: list[CommandResult], comparisons: list[CaseComparison], + sourcegraph_load_monitor: SourcegraphLoadMonitor | None, ) -> None: path.parent.mkdir(parents=True, exist_ok=True) + sourcegraph_monitor: dict[str, Any] | None = None + if sourcegraph_load_monitor is not None: + sourcegraph_monitor = { + "output_dir": str(sourcegraph_load_monitor.output_dir), + "log_path": str(sourcegraph_load_monitor.log_path), + } with path.open("w", encoding="utf-8") as output_file: json.dump( { "generated_at": datetime.datetime.now(datetime.UTC).isoformat(), + "sourcegraph_load_monitor": sourcegraph_monitor, "results": [result_to_json(result) for result in results], "comparisons": [comparison_to_json(comparison) for comparison in comparisons], }, @@ -2240,7 +2934,11 @@ def normalized_memory(result: CommandResult) -> dict[str, float]: if peak_rss_mb is None: return {} normalized: dict[str, float] = {} - for field_name in ("user_count", "total_users", "repo_count", "total_grants"): + for field_name in ( + "memory_model_user_count", + "memory_model_repo_count", + "memory_model_grant_count", + ): value = result.workload.get(field_name) if isinstance(value, int | float) and value > 0: normalized[f"peak_rss_mb_per_{field_name}"] = peak_rss_mb / float(value) diff --git a/dev/test-plan.md b/dev/test-plan.md index 521d165..8312541 100644 --- a/dev/test-plan.md +++ b/dev/test-plan.md @@ -118,6 +118,71 @@ These numbers don't have published baselines yet; this run *creates* them. The deliverable is "we now know `--set --apply` with `--parallelism 16` hits N MB RSS and W seconds of snapshot wall-clock at G grants." +### Memory-per-grant model + +Generate exact users × repos maps and, when ready, run them through the CLI: + +```bash +uv run python dev/run-memory-model-sweep.py + +uv run python dev/run-memory-model-sweep.py \ + --run \ + --parallelism 1 +``` + +The runner writes generated maps and `results.json` under +`src-auth-perms-sync-runs//memory-model-sweep//`. +It uses an inventory-aware `--cases auto` sweep and dry-run mode by default. +On an instance with 1K+ visible repos, `auto` includes repo-axis points up to +1K repos and mixed cases up to 100K planned grants. Use explicit cases for +larger stress points, and use `--mode apply-no-backup --allow-apply` only on a +scratch instance: + +```bash +uv run python dev/run-memory-model-sweep.py \ + --cases '1x1,10000x1,1x1000,100x1000,1000x1000,10000x1000' \ + --run \ + --parallelism 1 +``` + +Fit memory from repeated e2e JSON results instead of dividing one run's +`peak_rss_mb` by one run's grants: + +```bash +uv run python dev/analyze-memory.py results/*.json \ + --command set_full \ + --case-regex 'set-full' \ + --features users,repos,grants \ + --estimate-users 10000 \ + --estimate-repos 100 +``` + +The analyzer fits: + +```text +peak RSS MiB = intercept + users*b1 + repos*b2 + grants*b3 +``` + +Use one command mode per fit (`set_full` with backup, `set_full --no-backup`, +`restore`, etc.). Mixing modes smears fixed snapshot / apply costs into the +per-grant coefficient. + +On the sgdev test instance with 10,001 users and 1,023 visible repos, a +dry-run `10000x1000` case planned 10M grants. Before the lazy-union planner, +it measured about 651 MiB peak RSS; after Phase 1 in +[mapping-efficiency.md](./mapping-efficiency.md), the same case measured about +68 MiB. Re-measure after meaningful mapping or snapshot changes; these numbers +describe dry-run planning memory, not apply mutation throughput. + +The e2e `workload` object now uses event-aware names. In older result JSON, +`total_users: 40004` came from `apply_username_overwrites` and meant "username +entries in mutation payloads" (`4 mutated repos × 10001 users`), not total +Sourcegraph users. Likewise `repo_count: 575` came from a batch fetch and meant +"grant rows fetched for 25 users" (`25 × 23`), not distinct repos. New results +expose those as `apply_payload_grant_count` and +`batch_fetched_grant_count_max`, plus canonical `memory_model_user_count`, +`memory_model_repo_count`, and `memory_model_grant_count` fields for modeling. + --- ## Failure injection (scenario e) diff --git a/maps-example.yaml b/maps-example.yaml index c854791..e5dfb7e 100644 --- a/maps-example.yaml +++ b/maps-example.yaml @@ -1,26 +1,37 @@ # Auth provider → code host connection mapping rules # Maintain this file using auth-providers.yaml and code-hosts.yaml as references. # Those files are generated under src-auth-perms-sync-runs//. +# +# These examples cover every supported filter field: +# - users.authProvider: clientID, configID, displayName, samlGroup, serviceID, type +# - users.emails (verified email addresses) +# - users.usernames +# - repos.codeHostConnection: config, displayName, id, kind, url +# - repos.names +# - repos.regexes maps: -- name: All users from Line of Business 1 - User Group 1 get access to all repos synced from service account 1 +- name: SAML group users get all repos synced from one service account users: authProvider: + configID: okta samlGroup: LOB1-GROUP1 + type: saml repos: codeHostConnection: config: username: LOB1-SA1 -- name: All users from Line of Business 1 - User Group 2 get access to all repos synced from service account 2 +- name: Users from one exact auth provider get repos from one exact code host connection users: authProvider: - samlGroup: LOB1-GROUP2 + clientID: sourcegraph + displayName: Okta SAML + serviceID: https://idp.example.com/saml repos: codeHostConnection: - config: - username: LOB1-SA2 + id: 12 - name: All Okta SAML users get access to all Bitbucket repos users: @@ -31,9 +42,36 @@ maps: codeHostConnection: kind: BITBUCKETSERVER +- name: All builtin users get repos from the GitHub Cloud connection + users: + authProvider: + type: builtin + repos: + codeHostConnection: + displayName: GitHub Cloud + +- name: All builtin users get repos from the GitHub URL connection + users: + authProvider: + type: builtin + repos: + codeHostConnection: + url: https://github.com/ + +- name: Exact user gets named repos + users: + emails: + - alice@example.com + - bob@example.com + repos: + names: + - github.com/example/private-repo + - name: All builtin users get access to all repos under the github.com/example org, from any code host connection users: authProvider: type: builtin repos: - regex: https://github.com/example/.* + regexes: + - ^github\.com/example/.* + - ^gitlab\.com/example/.* diff --git a/pyproject.toml b/pyproject.toml index bd63e95..28c43d3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ classifiers = [ dependencies = [ "json5>=0.14.0", "pyyaml>=6.0.3", - "src-py-lib==0.1.5", + "src-py-lib==0.1.6", ] keywords = [ "Sourcegraph" diff --git a/src/src_auth_perms_sync/cli.py b/src/src_auth_perms_sync/cli.py index 8761abb..7e25143 100644 --- a/src/src_auth_perms_sync/cli.py +++ b/src/src_auth_perms_sync/cli.py @@ -192,6 +192,14 @@ class SrcAuthPermissionsSyncConfig(src.SourcegraphClientConfig, src.LoggingConfi ge=1, help="Max attempts per HTTP request before giving up (default: 5)", ) + http_timeout_seconds: float = src.config_field( + default=60.0, + env_var="SRC_AUTH_PERMS_SYNC_HTTP_TIMEOUT_SECONDS", + cli_flag="--http-timeout-seconds", + metavar="SECONDS", + gt=0, + help="HTTP read timeout per request in seconds (default: 60)", + ) sample_interval: float = src.config_field( default=10.0, env_var="SRC_AUTH_PERMS_SYNC_SAMPLE_INTERVAL", @@ -387,6 +395,7 @@ def run_fields( "explicit_permissions_batch_size": config.explicit_permissions_batch_size, "trace": config.trace, "max_attempts": config.max_attempts, + "http_timeout_seconds": config.http_timeout_seconds, "no_backup": config.no_backup, "sample_interval": config.sample_interval, "user_created_after": config.created_after, @@ -404,6 +413,7 @@ def run_with_client( ) -> None: """Create a client, run the selected command, and always close HTTP resources.""" http = src.HTTPClient( + timeout=config.http_timeout_seconds, user_agent="src-auth-perms-sync/0.1 (+python)", max_attempts=config.max_attempts, max_connections=config.parallelism, diff --git a/src/src_auth_perms_sync/permissions/apply.py b/src/src_auth_perms_sync/permissions/apply.py index 7849855..14969ba 100644 --- a/src/src_auth_perms_sync/permissions/apply.py +++ b/src/src_auth_perms_sync/permissions/apply.py @@ -301,12 +301,12 @@ def _apply_repo_overwrite_plans( ) -> shared_types.MutationCounts: """Dispatch per-repo overwrite mutations with bounded in-flight work.""" max_pending_futures = max(1, parallelism * 2) - total_users = sum(len(overwrite.usernames) for overwrite in overwrites) + payload_grant_count = sum(len(overwrite.usernames) for overwrite in overwrites) with src.event( "apply_username_overwrites", payload_count=len(overwrites), parallelism=parallelism, - total_users=total_users, + payload_grant_count=payload_grant_count, max_pending_futures=max_pending_futures, ) as batch_event: succeeded = 0 diff --git a/src/src_auth_perms_sync/permissions/command.py b/src/src_auth_perms_sync/permissions/command.py index 95d4041..fbdc00c 100644 --- a/src/src_auth_perms_sync/permissions/command.py +++ b/src/src_auth_perms_sync/permissions/command.py @@ -386,7 +386,14 @@ def cmd_set_additive_user( context = load_mapping_context(client, input_path, saml_groups_attribute_name_by_config_id) if context is None: return run_context.CommandData() - user = _resolve_user_identifier(client, user_identifier) + include_user_emails = permissions_mapping.mapping_rules_need_user_emails( + context.mapping_rules + ) + user = _resolve_user_identifier( + client, + user_identifier, + include_emails=include_user_emails, + ) if user_created_after is not None: candidate_user_ids = user_ids_created_on_or_after(client, user_created_after) if user["id"] not in candidate_user_ids: @@ -446,6 +453,9 @@ def cmd_set_additive_users_without_explicit_perms( context = load_mapping_context(client, input_path, saml_groups_attribute_name_by_config_id) if context is None: return run_context.CommandData() + include_user_emails = permissions_mapping.mapping_rules_need_user_emails( + context.mapping_rules + ) resolved_mappings = resolve_additive_mappings(context) candidates = permissions_sourcegraph.list_site_user_candidates(client, created_after_filter) log.info("Received %d non-deleted user candidate(s).", len(candidates)) @@ -455,7 +465,11 @@ def cmd_set_additive_users_without_explicit_perms( for candidate in candidates: if permissions_sourcegraph.user_has_explicit_repos(client, candidate["id"]): continue - user = permissions_sourcegraph.get_user_by_id(client, candidate["id"]) + user = permissions_sourcegraph.get_user_by_id( + client, + candidate["id"], + include_emails=include_user_emails, + ) if user is None: log.warning( "Skipping user candidate %s: user no longer exists.", @@ -492,18 +506,33 @@ def cmd_set_additive_users_without_explicit_perms( def _resolve_user_identifier( - client: src.SourcegraphClient, user_identifier: str + client: src.SourcegraphClient, + user_identifier: str, + *, + include_emails: bool = False, ) -> shared_types.User: """Resolve username/email input to one Sourcegraph user.""" user: shared_types.User | None if "@" in user_identifier: user = permissions_sourcegraph.get_user_by_email( - client, user_identifier - ) or permissions_sourcegraph.get_user_by_username(client, user_identifier) + client, + user_identifier, + include_emails=include_emails, + ) or permissions_sourcegraph.get_user_by_username( + client, + user_identifier, + include_emails=include_emails, + ) else: user = permissions_sourcegraph.get_user_by_username( - client, user_identifier - ) or permissions_sourcegraph.get_user_by_email(client, user_identifier) + client, + user_identifier, + include_emails=include_emails, + ) or permissions_sourcegraph.get_user_by_email( + client, + user_identifier, + include_emails=include_emails, + ) if user is None: raise SystemExit(f"No Sourcegraph user found for {user_identifier!r}.") if user["username"] != user_identifier: diff --git a/src/src_auth_perms_sync/permissions/full_set.py b/src/src_auth_perms_sync/permissions/full_set.py index 1ef3fda..6b79458 100644 --- a/src/src_auth_perms_sync/permissions/full_set.py +++ b/src/src_auth_perms_sync/permissions/full_set.py @@ -98,6 +98,7 @@ def _capture_full_set_snapshot_state( explicit_permissions_batch_size: int, bind_id_mode: str, worker_pool: ThreadPoolExecutor | None = None, + include_user_emails: bool = False, ) -> _FullSetUserState: """Load users while capturing the before-snapshot.""" total_users = shared_sourcegraph.count_users(client) @@ -110,7 +111,11 @@ def _capture_full_set_snapshot_state( before_timestamp = backups.backup_timestamp() before_snapshot = permission_snapshot.build_snapshot( client, - shared_sourcegraph.list_users_streaming(client, collect_into=users), + shared_sourcegraph.list_users_streaming( + client, + collect_into=users, + include_emails=include_user_emails, + ), parallelism, bind_id_mode, input_path, @@ -140,6 +145,7 @@ def _load_full_set_snapshot_state( bind_id_mode: str, capture_before: bool, worker_pool: ThreadPoolExecutor | None = None, + include_user_emails: bool = False, ) -> _FullSetUserState: """Load all users, optionally with a before-snapshot.""" if capture_before: @@ -150,10 +156,14 @@ def _load_full_set_snapshot_state( explicit_permissions_batch_size, bind_id_mode, worker_pool, + include_user_emails=include_user_emails, ) log.info("Loading users from %s ...", client.endpoint) - users = shared_sourcegraph.list_users_with_accounts(client) + users = shared_sourcegraph.list_users_with_accounts( + client, + include_emails=include_user_emails, + ) log.info("Received %d total users.", len(users)) return _FullSetUserState(users=users) @@ -256,12 +266,13 @@ def _write_noop_full_set_snapshots( return before_path, after_path, diff_path, maps_backup_path -def _plan_full_set_permissions( +def plan_full_set_permissions( context: permission_types.MappingContext, users: list[shared_types.User], ) -> _FullSetPlan: """Resolve mapping rules into one repo-to-users overwrite plan.""" - repo_usernames: dict[str, set[str]] = {} + expected_users: dict[str, tuple[str, ...]] = {} + union_usernames_by_repo_id: dict[str, set[str]] = {} repo_names: dict[str, str] = {} for mapping_index, mapping in enumerate(context.mapping_rules, start=1): @@ -293,15 +304,28 @@ def _plan_full_set_permissions( log.warning(" No repos matched — skipping rule.") continue - matched_usernames = tuple(user["username"] for user in matched_users) + matched_usernames = tuple(sorted({user["username"] for user in matched_users})) for repo in matched_repos: - bucket = repo_usernames.setdefault(repo["id"], set()) - repo_names[repo["id"]] = repo["name"] - bucket.update(matched_usernames) + repo_id = repo["id"] + repo_names[repo_id] = repo["name"] + union_usernames = union_usernames_by_repo_id.get(repo_id) + if union_usernames is not None: + union_usernames.update(matched_usernames) + continue + + existing_usernames = expected_users.get(repo_id) + if existing_usernames is not None: + union_usernames = set(existing_usernames) + union_usernames.update(matched_usernames) + union_usernames_by_repo_id[repo_id] = union_usernames + del expected_users[repo_id] + continue + + expected_users[repo_id] = matched_usernames + + for repo_id, usernames in union_usernames_by_repo_id.items(): + expected_users[repo_id] = tuple(sorted(usernames)) - expected_users = { - repo_id: tuple(sorted(usernames)) for repo_id, usernames in repo_usernames.items() - } total_grants = sum(len(usernames) for usernames in expected_users.values()) if expected_users: log.info( @@ -656,6 +680,7 @@ def _load_full_set_plan( retain_saml_group_users: bool, worker_pool: ThreadPoolExecutor | None = None, ) -> _FullSetLoadedPlan: + include_user_emails = permissions_mapping.mapping_rules_need_user_emails(mapping_rules) user_state = _load_full_set_snapshot_state( client, input_path, @@ -664,6 +689,7 @@ def _load_full_set_plan( bind_id_mode, capture_before=capture_before, worker_pool=worker_pool, + include_user_emails=include_user_emails, ) before_path: Path | None = None if capture_before: @@ -687,7 +713,7 @@ def _load_full_set_plan( user_state.users, user_created_after, ) - plan = _plan_full_set_permissions(context, users) + plan = plan_full_set_permissions(context, users) snapshot_state = _compact_full_set_snapshot_state(user_state, users) saml_group_users = ( saml_groups.compact_saml_group_users( diff --git a/src/src_auth_perms_sync/permissions/mapping.py b/src/src_auth_perms_sync/permissions/mapping.py index 904a63e..33a9fb0 100644 --- a/src/src_auth_perms_sync/permissions/mapping.py +++ b/src/src_auth_perms_sync/permissions/mapping.py @@ -1,11 +1,11 @@ """Permission mapping resolution: validate rules and match users/repos. Each mapping rule has a `users:` section and a `repos:` section, each -containing one or more matchers (today: `authProvider`, -`codeHostConnection`, and `regex`). Within a matcher, the supplied -keys AND together against the discovered auth-provider / external- -service entries. Across mapping rules, `cmd_set` unions the per-repo -user sets at apply time — see `src/src_auth_perms_sync/permissions/types.py` for the rationale. +containing one or more matchers. Within a matcher, the supplied keys +AND together against the discovered auth-provider / external-service +entries. Across sibling matchers, results intersect. Across mapping +rules, `cmd_set` unions the per-repo user sets at apply time — see +`src/src_auth_perms_sync/permissions/types.py` for the rationale. Adding a new matcher type: @@ -103,7 +103,13 @@ def validate_mapping_rules(rules: list[permission_types.MappingRule]) -> None: ) -_KNOWN_USER_MATCHERS: set[str] = {"authProvider"} +_KNOWN_USER_MATCHERS: set[str] = {"authProvider", "emails", "usernames"} +_KNOWN_REPO_MATCHERS: set[str] = {"codeHostConnection", "names", "regexes"} + + +def mapping_rules_need_user_emails(mapping_rules: list[permission_types.MappingRule]) -> bool: + """Return whether any mapping rule filters users by verified email.""" + return any("emails" in mapping["users"] for mapping in mapping_rules) def _validate_users_section(section: dict[str, object], prefix: str) -> list[str]: @@ -123,6 +129,10 @@ def _validate_users_section(section: dict[str, object], prefix: str) -> list[str ) if "samlGroup" in auth_provider: errors.extend(_validate_saml_group(auth_provider, prefix)) + if "emails" in section: + errors.extend(_validate_string_list(section["emails"], prefix, "users.emails")) + if "usernames" in section: + errors.extend(_validate_string_list(section["usernames"], prefix, "users.usernames")) return errors @@ -157,7 +167,7 @@ def _validate_repos_section(section: dict[str, object], prefix: str) -> list[str """Reject unknown matcher keys and validate `codeHostConnection:` shape.""" errors: list[str] = [] for key in section: - if key not in {"codeHostConnection", "regex"}: + if key not in _KNOWN_REPO_MATCHERS: errors.append(f"{prefix}: unknown repos matcher {key!r}") code_host_section = cast(dict[str, object] | None, section.get("codeHostConnection")) if code_host_section is not None: @@ -190,17 +200,46 @@ def _validate_repos_section(section: dict[str, object], prefix: str) -> list[str f"key/value pairs to deep-subset-match against the service's " f"parsed config (got {type(code_host_section['config']).__name__})" ) - regex = section.get("regex") - if regex is not None: - if not isinstance(regex, str): - errors.append(f"{prefix}: repos.regex must be a string (got {type(regex).__name__})") - elif not regex: - errors.append(f"{prefix}: repos.regex is an empty string") - else: - try: - re.compile(regex) - except re.error as exception: - errors.append(f"{prefix}: repos.regex is not a valid Python regex: {exception}") + if "names" in section: + errors.extend(_validate_string_list(section["names"], prefix, "repos.names")) + regexes = section.get("regexes") + if regexes is not None: + errors.extend(_validate_regexes(regexes, prefix)) + return errors + + +def _validate_regexes(value: object, prefix: str) -> list[str]: + """Validate list-based regex filters.""" + errors = _validate_string_list(value, prefix, "repos.regexes") + if errors: + return errors + + for index, pattern in enumerate(cast(list[str], value)): + try: + re.compile(pattern) + except re.error as exception: + errors.append( + f"{prefix}: repos.regexes[{index}] is not a valid Python regex: {exception}" + ) + return errors + + +def _validate_string_list(value: object, prefix: str, path: str) -> list[str]: + """Validate list-based exact-match filters.""" + if not isinstance(value, list): + return [f"{prefix}: {path} must be a list of strings (got {type(value).__name__})"] + + items = cast(list[object], value) + errors: list[str] = [] + if not items: + errors.append(f"{prefix}: {path} is empty (matches nothing)") + for index, item in enumerate(items): + if not isinstance(item, str): + errors.append( + f"{prefix}: {path}[{index}] must be a string (got {type(item).__name__} {item!r})" + ) + elif not item: + errors.append(f"{prefix}: {path}[{index}] is an empty string") return errors @@ -243,6 +282,15 @@ def resolve_users( saml_groups_attribute_names, ) } + elif key == "emails": + current_ids = { + user["id"] for user in _users_matching_emails(cast(list[str], matcher), all_users) + } + elif key == "usernames": + current_ids = { + user["id"] + for user in _users_matching_usernames(cast(list[str], matcher), all_users) + } else: # validate_mapping_rules catches this earlier with a clearer # message; this only fires for programmatic callers. @@ -273,6 +321,12 @@ def user_matches_users_section( saml_groups_attribute_names, ): return False + elif key == "emails": + if not _user_matches_emails(user, cast(list[str], matcher)): + return False + elif key == "usernames": + if user["username"] not in cast(list[str], matcher): + return False else: # validate_mapping_rules catches this earlier with a clearer # message; this only fires for programmatic callers. @@ -280,6 +334,38 @@ def user_matches_users_section( return True +def _users_matching_emails( + emails: list[str], all_users: list[shared_types.User] +) -> list[shared_types.User]: + """Return users with at least one verified email in `emails`.""" + matched = [user for user in all_users if _user_matches_emails(user, emails)] + log.info(" emails → %d user(s) matched %d email(s)", len(matched), len(set(emails))) + return matched + + +def _user_matches_emails(user: shared_types.User, emails: list[str]) -> bool: + """Match only verified emails, mirroring Sourcegraph's `user(email:)` lookup.""" + email_set = set(emails) + return any( + user_email["verified"] and user_email["email"] in email_set + for user_email in user.get("emails", []) + ) + + +def _users_matching_usernames( + usernames: list[str], all_users: list[shared_types.User] +) -> list[shared_types.User]: + """Return users whose Sourcegraph username is listed exactly.""" + username_set = set(usernames) + matched = [user for user in all_users if user["username"] in username_set] + log.info( + " usernames → %d user(s) matched %d username(s)", + len(matched), + len(username_set), + ) + return matched + + def _users_matching_auth_provider( matcher: permission_types.AuthProviderMatcher, all_users: list[shared_types.User], @@ -449,7 +535,7 @@ def resolve_repos( matched_ids: set[str] | None = None repo_index: dict[str, permission_types.Repository] = {} - ordered_keys = [key for key in ("codeHostConnection", "regex") if key in section] + ordered_keys = [key for key in ("codeHostConnection", "names", "regexes") if key in section] for key in ordered_keys: matcher = section[key] if key == "codeHostConnection": @@ -458,13 +544,20 @@ def resolve_repos( services_by_id, repos_by_external_service_id, ) - elif key == "regex": + elif key == "names": + candidate_repos = ( + [repo_index[repo_id] for repo_id in matched_ids] + if matched_ids is not None + else list(all_repos_by_id.values()) + ) + repos = _repos_matching_names(cast(list[str], matcher), candidate_repos) + elif key == "regexes": candidate_repos = ( [repo_index[repo_id] for repo_id in matched_ids] if matched_ids is not None else list(all_repos_by_id.values()) ) - repos = _repos_matching_regex(cast(str, matcher), candidate_repos) + repos = _repos_matching_regexes(cast(list[str], matcher), candidate_repos) else: # validate_mapping_rules catches this earlier with a clearer # message; this only fires for programmatic callers. @@ -479,6 +572,16 @@ def resolve_repos( return [repo_index[repo_id] for repo_id in matched_ids] +def _repos_matching_names( + names: list[str], repos: list[permission_types.Repository] +) -> list[permission_types.Repository]: + """Return repos whose Sourcegraph name is listed exactly.""" + name_set = set(names) + matched = [repo for repo in repos if repo["name"] in name_set] + log.info(" names → %d repo(s) matched %d name(s)", len(matched), len(name_set)) + return matched + + def _repos_matching_code_host_connection( matcher: permission_types.CodeHostConnectionMatcher, services_by_id: dict[int, permission_types.ExternalService], @@ -505,22 +608,26 @@ def _repos_matching_code_host_connection( return list(matched_repos.values()) -def _repos_matching_regex( - pattern: str, repos: list[permission_types.Repository] +def _repos_matching_regexes( + patterns: list[str], repos: list[permission_types.Repository] ) -> list[permission_types.Repository]: - """Return repos whose name matches `pattern` using Python `re`. + """Return repos whose name matches any pattern using Python `re`. Sourcegraph repo names usually omit the URL scheme (for example `github.com/example/repo`). To keep URL-looking operator patterns useful, also test `https://`. """ - compiled = re.compile(pattern) + compiled_patterns = [re.compile(pattern) for pattern in patterns] matched = [ repo for repo in repos - if compiled.search(repo["name"]) or compiled.search(f"https://{repo['name']}") + if any( + compiled_pattern.search(repo["name"]) + or compiled_pattern.search(f"https://{repo['name']}") + for compiled_pattern in compiled_patterns + ) ] - log.info(" regex → %d repo(s) matched %r", len(matched), pattern) + log.info(" regexes → %d repo(s) matched %d pattern(s)", len(matched), len(patterns)) return matched diff --git a/src/src_auth_perms_sync/permissions/queries.py b/src/src_auth_perms_sync/permissions/queries.py index afa83b5..7a7e2ec 100644 --- a/src/src_auth_perms_sync/permissions/queries.py +++ b/src/src_auth_perms_sync/permissions/queries.py @@ -89,32 +89,57 @@ } """ -QUERY_USER_BY_USERNAME = f""" +USER_EMAIL_FIELDS = """ +emails { + email + verified +} +""" + + +def user_fields(*, include_emails: bool = False) -> str: + """Return user fields, adding emails only when downstream matching needs them.""" + if include_emails: + return f"{USER_FIELDS}\n{USER_EMAIL_FIELDS}" + return USER_FIELDS + + +def query_user_by_username(*, include_emails: bool = False) -> str: + return f""" query UserByUsername($username: String!) {{ user(username: $username) {{ - {USER_FIELDS} + {user_fields(include_emails=include_emails)} }} }} """ -QUERY_USER_BY_EMAIL = f""" + +def query_user_by_email(*, include_emails: bool = False) -> str: + return f""" query UserByEmail($email: String!) {{ user(email: $email) {{ - {USER_FIELDS} + {user_fields(include_emails=include_emails)} }} }} """ -QUERY_USER_BY_ID = f""" + +def query_user_by_id(*, include_emails: bool = False) -> str: + return f""" query UserByID($id: ID!) {{ node(id: $id) {{ ... on User {{ - {USER_FIELDS} + {user_fields(include_emails=include_emails)} }} }} }} """ + +QUERY_USER_BY_USERNAME = query_user_by_username() +QUERY_USER_BY_EMAIL = query_user_by_email() +QUERY_USER_BY_ID = query_user_by_id() + QUERY_SITE_USERS = """ query SiteUsers($limit: Int!, $offset: Int!, $createdAt: SiteUsersDateRangeInput) { site { diff --git a/src/src_auth_perms_sync/permissions/snapshot.py b/src/src_auth_perms_sync/permissions/snapshot.py index ae3928c..3073205 100644 --- a/src/src_auth_perms_sync/permissions/snapshot.py +++ b/src/src_auth_perms_sync/permissions/snapshot.py @@ -230,7 +230,7 @@ def _fetch( user["username"]: repository_ids_by_user_id.get(user["id"], []) for user in batch_users } - fetch_event["repo_count"] = sum( + fetch_event["fetched_grant_count"] = sum( len(repository_ids) for repository_ids in repository_ids_by_username.values() ) fetch_event["per_user_failures"] = failures diff --git a/src/src_auth_perms_sync/permissions/sourcegraph.py b/src/src_auth_perms_sync/permissions/sourcegraph.py index c64c4b4..ed4e32b 100644 --- a/src/src_auth_perms_sync/permissions/sourcegraph.py +++ b/src/src_auth_perms_sync/permissions/sourcegraph.py @@ -41,29 +41,53 @@ def list_repos_for_external_service( ] -def get_user_by_username(client: src.SourcegraphClient, username: str) -> shared_types.User | None: +def get_user_by_username( + client: src.SourcegraphClient, + username: str, + *, + include_emails: bool = False, +) -> shared_types.User | None: """Return the exact Sourcegraph user for `username`, if it exists.""" data = cast( dict[str, Any], - client.graphql(queries.QUERY_USER_BY_USERNAME, cast(src.JSONDict, {"username": username})), + client.graphql( + queries.query_user_by_username(include_emails=include_emails), + cast(src.JSONDict, {"username": username}), + ), ) return cast(shared_types.User | None, data.get("user")) -def get_user_by_email(client: src.SourcegraphClient, email: str) -> shared_types.User | None: +def get_user_by_email( + client: src.SourcegraphClient, + email: str, + *, + include_emails: bool = False, +) -> shared_types.User | None: """Return the user owning the verified email address, if it exists.""" data = cast( dict[str, Any], - client.graphql(queries.QUERY_USER_BY_EMAIL, cast(src.JSONDict, {"email": email})), + client.graphql( + queries.query_user_by_email(include_emails=include_emails), + cast(src.JSONDict, {"email": email}), + ), ) return cast(shared_types.User | None, data.get("user")) -def get_user_by_id(client: src.SourcegraphClient, user_id: str) -> shared_types.User | None: +def get_user_by_id( + client: src.SourcegraphClient, + user_id: str, + *, + include_emails: bool = False, +) -> shared_types.User | None: """Hydrate a User node by GraphQL ID.""" data = cast( dict[str, Any], - client.graphql(queries.QUERY_USER_BY_ID, cast(src.JSONDict, {"id": user_id})), + client.graphql( + queries.query_user_by_id(include_emails=include_emails), + cast(src.JSONDict, {"id": user_id}), + ), ) return cast(shared_types.User | None, data.get("node")) @@ -171,11 +195,10 @@ def list_users_explicit_repo_ids( repository_ids_by_user_id: dict[str, list[str]] = {user_id: [] for user_id in user_ids} pending_pages: list[tuple[str, str | None]] = [(user_id, None) for user_id in user_ids] - graphql_client = _graphql_client_without_auto_pagination(client) while pending_pages: batch = pending_pages[:batch_size] del pending_pages[:batch_size] - data = graphql_client.execute( + data = client.graphql( _user_explicit_repos_batch_query(len(batch)), _user_explicit_repos_batch_variables(batch), follow_pages=False, @@ -237,15 +260,6 @@ def list_repositories_by_ids( return repositories -def _graphql_client_without_auto_pagination(client: src.SourcegraphClient) -> src.GraphQLClient: - return src.GraphQLClient( - url=f"{client.endpoint}/.api/graphql", - headers={"Authorization": f"token {client.token}"}, - label="Sourcegraph", - http=client.http, - ) - - def _batches(values: Sequence[str], batch_size: int) -> Iterator[Sequence[str]]: for start_index in range(0, len(values), batch_size): yield values[start_index : start_index + batch_size] diff --git a/src/src_auth_perms_sync/permissions/types.py b/src/src_auth_perms_sync/permissions/types.py index f57ffab..a320892 100644 --- a/src/src_auth_perms_sync/permissions/types.py +++ b/src/src_auth_perms_sync/permissions/types.py @@ -85,11 +85,14 @@ class CodeHostConnectionMatcher(TypedDict, total=False): class UsersFilter(TypedDict, total=False): authProvider: AuthProviderMatcher + emails: list[str] + usernames: list[str] class ReposFilter(TypedDict, total=False): codeHostConnection: CodeHostConnectionMatcher - regex: str + names: list[str] + regexes: list[str] class MappingRule(TypedDict): diff --git a/src/src_auth_perms_sync/shared/queries.py b/src/src_auth_perms_sync/shared/queries.py index 4ffb1e4..c833e42 100644 --- a/src/src_auth_perms_sync/shared/queries.py +++ b/src/src_auth_perms_sync/shared/queries.py @@ -38,15 +38,25 @@ } """ -QUERY_USERS = """ -query ListUsers($first: Int!, $after: String) { - users(first: $first, after: $after) { - nodes { +USER_EMAIL_FIELDS = """ emails { + email + verified + } +""" + + +def query_users(*, include_emails: bool = False) -> str: + """Return the users page query, adding email fields only when requested.""" + email_fields = USER_EMAIL_FIELDS if include_emails else "" + return f""" +query ListUsers($first: Int!, $after: String) {{ + users(first: $first, after: $after) {{ + nodes {{ id username builtinAuth - externalAccounts(first: 50) { - nodes { +{email_fields} externalAccounts(first: 50) {{ + nodes {{ serviceType serviceID clientID @@ -56,10 +66,13 @@ # Admin. Returns null for serviceType where the resolver does # not expose data (e.g. plain GitHub OAuth without SSO). accountData - } - } - } - pageInfo { hasNextPage endCursor } - } -} + }} + }} + }} + pageInfo {{ hasNextPage endCursor }} + }} +}} """ + + +QUERY_USERS = query_users() diff --git a/src/src_auth_perms_sync/shared/sourcegraph.py b/src/src_auth_perms_sync/shared/sourcegraph.py index 3b39d94..f138e2d 100644 --- a/src/src_auth_perms_sync/shared/sourcegraph.py +++ b/src/src_auth_perms_sync/shared/sourcegraph.py @@ -32,11 +32,15 @@ def count_users(client: src.SourcegraphClient) -> int: return cast(int, data["users"]["totalCount"]) -def list_users_with_accounts(client: src.SourcegraphClient) -> list[shared_types.User]: +def list_users_with_accounts( + client: src.SourcegraphClient, + *, + include_emails: bool = False, +) -> list[shared_types.User]: return [ cast(shared_types.User, node) for node in client.stream_connection_nodes( - queries.QUERY_USERS, + queries.query_users(include_emails=include_emails), connection_path=("users",), page_size=DEFAULT_PAGE_SIZE, ) @@ -46,6 +50,8 @@ def list_users_with_accounts(client: src.SourcegraphClient) -> list[shared_types def list_users_streaming( client: src.SourcegraphClient, collect_into: list[shared_types.User] | None = None, + *, + include_emails: bool = False, ) -> Iterator[shared_types.User]: """Stream ListUsers pages one at a time, yielding each User as it arrives. @@ -59,7 +65,7 @@ def list_users_streaming( streaming benefit in one pass — no double-pagination. """ for node in client.stream_connection_nodes( - queries.QUERY_USERS, + queries.query_users(include_emails=include_emails), connection_path=("users",), page_size=DEFAULT_PAGE_SIZE, ): diff --git a/src/src_auth_perms_sync/shared/types.py b/src/src_auth_perms_sync/shared/types.py index 9429a41..7ac9096 100644 --- a/src/src_auth_perms_sync/shared/types.py +++ b/src/src_auth_perms_sync/shared/types.py @@ -30,11 +30,17 @@ class ExternalAccountConnection(TypedDict): nodes: list[ExternalAccount] +class UserEmail(TypedDict): + email: str + verified: bool + + class User(TypedDict): id: str username: str builtinAuth: bool externalAccounts: ExternalAccountConnection + emails: NotRequired[list[UserEmail]] @dataclass(frozen=True, slots=True) diff --git a/tests/unit/test_cli_config.py b/tests/unit/test_cli_config.py index fe6ef17..7c47f77 100644 --- a/tests/unit/test_cli_config.py +++ b/tests/unit/test_cli_config.py @@ -180,6 +180,15 @@ def test_explicit_permissions_batch_size_rejects_values_below_one(self) -> None: with self.assertRaisesRegex(shared_config.ConfigError, "greater than or equal to 1"): load_config_from_env(SRC_AUTH_PERMS_SYNC_EXPLICIT_PERMISSIONS_BATCH_SIZE="0") + def test_http_timeout_config_is_loaded_from_env(self) -> None: + config = load_config_from_env(SRC_AUTH_PERMS_SYNC_HTTP_TIMEOUT_SECONDS="90") + + self.assertEqual(90, config.http_timeout_seconds) + + def test_http_timeout_rejects_values_at_or_below_zero(self) -> None: + with self.assertRaisesRegex(shared_config.ConfigError, "greater than 0"): + load_config_from_env(SRC_AUTH_PERMS_SYNC_HTTP_TIMEOUT_SECONDS="0") + def test_trace_config_is_loaded_from_env(self) -> None: config = load_config_from_env(SRC_AUTH_PERMS_SYNC_TRACE="true") @@ -212,6 +221,33 @@ def capture_client( self.assertEqual(1, len(captured_clients)) self.assertTrue(captured_clients[0].trace) + def test_run_with_client_uses_configured_http_timeout(self) -> None: + configuration = make_config(http_timeout_seconds=75.0) + command = cli.resolve_command(configuration) + captured_clients: list[src.SourcegraphClient] = [] + + def capture_client( + _config: cli.SrcAuthPermissionsSyncConfig, + _command: cli.ResolvedCommand, + client: src.SourcegraphClient, + _worker_pool: ThreadPoolExecutor, + ) -> None: + captured_clients.append(client) + + with ( + ThreadPoolExecutor(max_workers=1) as worker_pool, + mock.patch.object(cli, "run_command", side_effect=capture_client), + ): + cli.run_with_client( + configuration, + command, + "https://sourcegraph.example.com", + worker_pool, + ) + + self.assertEqual(1, len(captured_clients)) + self.assertEqual(75.0, captured_clients[0].http.timeout) + def test_validate_config_rejects_multiple_set_modes(self) -> None: self.assert_config_error( make_config(set_path=Path("maps.yaml"), full=True, user="alice"), @@ -249,6 +285,7 @@ def test_run_fields_include_concrete_command(self) -> None: self.assertEqual(True, fields["apply_flag"]) self.assertEqual(25, fields["explicit_permissions_batch_size"]) self.assertEqual(False, fields["trace"]) + self.assertEqual(60.0, fields["http_timeout_seconds"]) def test_run_command_passes_primary_data_to_combined_sync(self) -> None: configuration = make_config(get=True, sync_saml_organizations=True) diff --git a/tests/unit/test_maps.py b/tests/unit/test_maps.py index 34bb483..3dcbe77 100644 --- a/tests/unit/test_maps.py +++ b/tests/unit/test_maps.py @@ -1,12 +1,18 @@ from __future__ import annotations +import base64 +import itertools import tempfile import unittest from pathlib import Path +from typing import cast import yaml -from src_auth_perms_sync.permissions import maps +from src_auth_perms_sync.permissions import full_set, mapping, maps +from src_auth_perms_sync.permissions import queries as permission_queries +from src_auth_perms_sync.permissions import types as permission_types +from src_auth_perms_sync.shared import queries as shared_queries from src_auth_perms_sync.shared import types as shared_types @@ -73,3 +79,379 @@ def test_count_users_per_provider_counts_each_user_once_per_provider(self) -> No self.assertEqual(1, counts[maps.BUILTIN_PROVIDER_KEY]) self.assertEqual(1, counts[("saml", "https://idp.example.com", "sourcegraph")]) self.assertEqual(1, counts[("github", "https://github.com/", "github-client")]) + + +class MappingTests(unittest.TestCase): + def test_mapping_rules_need_user_emails_tracks_email_filters(self) -> None: + rules_without_email_filters = cast( + list[permission_types.MappingRule], + [ + { + "users": {"usernames": ["alice"]}, + "repos": {"names": ["github.com/example/private-repo"]}, + } + ], + ) + rules_with_email_filters = cast( + list[permission_types.MappingRule], + [ + { + "users": {"emails": ["alice@example.com"]}, + "repos": {"names": ["github.com/example/private-repo"]}, + } + ], + ) + + self.assertFalse(mapping.mapping_rules_need_user_emails(rules_without_email_filters)) + self.assertTrue(mapping.mapping_rules_need_user_emails(rules_with_email_filters)) + + def test_user_filter_matchers_intersect_without_expanding_selection(self) -> None: + providers: list[shared_types.AuthProvider] = [ + { + "serviceType": "builtin", + "serviceID": "", + "clientID": "", + "displayName": "Builtin", + "isBuiltin": True, + "configID": "", + } + ] + users = [ + self.make_user("user-1", "alice", True, "alice@example.com", True), + self.make_user("user-2", "bob", True, "bob@example.com", True), + self.make_user("user-3", "carol", True, "carol@example.com", False), + self.make_user("user-4", "dana", False, "dana@example.com", True), + ] + user_filters: dict[str, object] = { + "authProvider": {"type": "builtin"}, + "emails": ["alice@example.com", "carol@example.com", "dana@example.com"], + "usernames": ["alice", "bob", "carol"], + } + single_filter_usernames = { + name: self.usernames_for( + mapping.resolve_users({name: matcher}, users, providers), + ) + for name, matcher in user_filters.items() + } + + for filter_count in range(2, len(user_filters) + 1): + for filter_names in itertools.combinations(user_filters, filter_count): + matched_usernames = self.usernames_for( + mapping.resolve_users( + {name: user_filters[name] for name in filter_names}, + users, + providers, + ) + ) + expected_usernames = self.intersection_for(filter_names, single_filter_usernames) + + self.assertEqual(expected_usernames, matched_usernames) + for name in filter_names: + self.assertLessEqual(matched_usernames, single_filter_usernames[name]) + + self.assertEqual( + {"alice"}, + self.usernames_for(mapping.resolve_users(user_filters, users, providers)), + ) + + def test_repo_filter_matchers_intersect_without_expanding_selection(self) -> None: + sourcegraph_repo = self.make_repo("repo-1", "github.com/sourcegraph/sourcegraph") + example_private_repo = self.make_repo("repo-2", "github.com/example/private-repo") + gitlab_repo = self.make_repo("repo-3", "gitlab.com/example/private-repo") + example_public_repo = self.make_repo("repo-4", "github.com/example/public-repo") + all_repos = { + sourcegraph_repo["id"]: sourcegraph_repo, + example_private_repo["id"]: example_private_repo, + gitlab_repo["id"]: gitlab_repo, + example_public_repo["id"]: example_public_repo, + } + services_by_id = { + 1: self.make_external_service(1, "GITHUB", "GitHub Enterprise"), + 2: self.make_external_service(2, "GITHUB", "GitHub Cloud"), + } + repos_by_external_service_id = { + 1: [sourcegraph_repo, example_private_repo, gitlab_repo], + 2: [example_public_repo], + } + repo_filters: dict[str, object] = { + "codeHostConnection": {"id": 1}, + "names": [ + "github.com/example/private-repo", + "gitlab.com/example/private-repo", + ], + "regexes": [ + r"^github\.com/example/", + r"^gitlab\.com/example/", + ], + } + single_filter_repo_names = { + name: self.repo_names_for( + mapping.resolve_repos( + {name: matcher}, + services_by_id, + repos_by_external_service_id, + all_repos, + ) + ) + for name, matcher in repo_filters.items() + } + + for filter_count in range(2, len(repo_filters) + 1): + for filter_names in itertools.combinations(repo_filters, filter_count): + matched_repo_names = self.repo_names_for( + mapping.resolve_repos( + {name: repo_filters[name] for name in filter_names}, + services_by_id, + repos_by_external_service_id, + all_repos, + ) + ) + expected_repo_names = self.intersection_for(filter_names, single_filter_repo_names) + + self.assertEqual(expected_repo_names, matched_repo_names) + for name in filter_names: + self.assertLessEqual(matched_repo_names, single_filter_repo_names[name]) + + self.assertEqual( + {"github.com/example/private-repo", "gitlab.com/example/private-repo"}, + self.repo_names_for( + mapping.resolve_repos( + repo_filters, + services_by_id, + repos_by_external_service_id, + all_repos, + ) + ), + ) + + def test_validate_mapping_rules_accepts_string_list_filters(self) -> None: + mapping.validate_mapping_rules( + cast( + list[permission_types.MappingRule], + [ + { + "users": { + "emails": ["alice@example.com"], + "usernames": ["alice"], + }, + "repos": { + "names": ["github.com/example/private-repo"], + "regexes": [r"^github\.com/example/"], + }, + } + ], + ) + ) + + def test_repos_regexes_match_any_pattern(self) -> None: + sourcegraph_repo = self.make_repo("repo-1", "github.com/sourcegraph/sourcegraph") + github_repo = self.make_repo("repo-2", "github.com/example/private-repo") + gitlab_repo = self.make_repo("repo-3", "gitlab.com/example/private-repo") + all_repos = { + sourcegraph_repo["id"]: sourcegraph_repo, + github_repo["id"]: github_repo, + gitlab_repo["id"]: gitlab_repo, + } + + matched_repos = mapping.resolve_repos( + { + "regexes": [ + r"^github\.com/example/", + r"^gitlab\.com/example/", + ], + }, + {}, + {}, + all_repos, + ) + + self.assertEqual( + {"github.com/example/private-repo", "gitlab.com/example/private-repo"}, + self.repo_names_for(matched_repos), + ) + + def test_validate_mapping_rules_rejects_non_string_list_filters(self) -> None: + with self.assertRaises(SystemExit) as raised: + mapping.validate_mapping_rules( + cast( + list[permission_types.MappingRule], + [ + { + "users": { + "emails": "alice@example.com", + "usernames": [""], + }, + "repos": { + "names": [123], + "regexes": ["["], + }, + }, + { + "users": {"usernames": ["alice"]}, + "repos": {"regex": r"^github\.com/example/"}, + }, + ], + ) + ) + + message = str(raised.exception) + self.assertIn("users.emails must be a list of strings", message) + self.assertIn("users.usernames[0] is an empty string", message) + self.assertIn("repos.names[0] must be a string", message) + self.assertIn("repos.regexes[0] is not a valid Python regex", message) + self.assertIn("unknown repos matcher 'regex'", message) + + def make_user( + self, + user_id: str, + username: str, + builtin_auth: bool, + email: str, + verified: bool, + ) -> shared_types.User: + return { + "id": user_id, + "username": username, + "builtinAuth": builtin_auth, + "emails": [{"email": email, "verified": verified}], + "externalAccounts": {"nodes": []}, + } + + def make_repo(self, repo_id: str, name: str) -> permission_types.Repository: + return {"id": repo_id, "name": name} + + def make_external_service( + self, + external_service_id: int, + kind: str, + display_name: str, + ) -> permission_types.ExternalService: + graphql_id = base64.b64encode(f"ExternalService:{external_service_id}".encode()).decode() + return { + "id": graphql_id, + "kind": kind, + "displayName": display_name, + "url": f"https://code-host-{external_service_id}.example.com", + "repoCount": 0, + "createdAt": "2026-05-30T00:00:00Z", + "updatedAt": "2026-05-30T00:00:00Z", + "lastSyncAt": None, + "nextSyncAt": None, + "lastSyncError": None, + "warning": None, + "unrestricted": False, + "suspended": False, + "hasConnectionCheck": False, + "supportsRepoExclusion": False, + "creator": None, + "lastUpdater": None, + "config": "{}", + } + + def usernames_for(self, users: list[shared_types.User]) -> set[str]: + return {user["username"] for user in users} + + def repo_names_for(self, repos: list[permission_types.Repository]) -> set[str]: + return {repo["name"] for repo in repos} + + def intersection_for( + self, names: tuple[str, ...], sets_by_name: dict[str, set[str]] + ) -> set[str]: + matched = set(sets_by_name[names[0]]) + for name in names[1:]: + matched &= sets_by_name[name] + return matched + + +class FullSetPlanningTests(unittest.TestCase): + def test_full_set_plan_reuses_user_tuple_for_non_overlapping_repos(self) -> None: + users = [self.make_user("user-1", "bob"), self.make_user("user-2", "alice")] + repositories = [ + self.make_repo("repo-1", "github.com/example/one"), + self.make_repo("repo-2", "github.com/example/two"), + ] + context = self.make_context( + [ + { + "users": {"usernames": ["alice", "bob"]}, + "repos": {"names": ["github.com/example/one", "github.com/example/two"]}, + } + ], + repositories, + ) + + plan = full_set.plan_full_set_permissions(context, users) + + self.assertEqual(("alice", "bob"), plan.expected_users["repo-1"]) + self.assertEqual(("alice", "bob"), plan.expected_users["repo-2"]) + self.assertIs(plan.expected_users["repo-1"], plan.expected_users["repo-2"]) + self.assertEqual(4, plan.total_grants) + + def test_full_set_plan_unions_only_overlapping_repos(self) -> None: + users = [ + self.make_user("user-1", "alice"), + self.make_user("user-2", "bob"), + self.make_user("user-3", "chris"), + ] + repositories = [ + self.make_repo("repo-1", "github.com/example/one"), + self.make_repo("repo-2", "github.com/example/two"), + self.make_repo("repo-3", "github.com/example/three"), + ] + context = self.make_context( + [ + { + "users": {"usernames": ["alice", "bob"]}, + "repos": {"names": ["github.com/example/one", "github.com/example/two"]}, + }, + { + "users": {"usernames": ["bob", "chris"]}, + "repos": {"names": ["github.com/example/two", "github.com/example/three"]}, + }, + ], + repositories, + ) + + plan = full_set.plan_full_set_permissions(context, users) + + self.assertEqual(("alice", "bob"), plan.expected_users["repo-1"]) + self.assertEqual(("alice", "bob", "chris"), plan.expected_users["repo-2"]) + self.assertEqual(("bob", "chris"), plan.expected_users["repo-3"]) + self.assertEqual(7, plan.total_grants) + + def make_context( + self, + mapping_rules: list[permission_types.MappingRule], + repositories: list[permission_types.Repository], + ) -> permission_types.MappingContext: + return permission_types.MappingContext( + mapping_rules=mapping_rules, + providers=[], + saml_groups_attribute_names={}, + services_by_id={}, + repos_by_external_service_id={}, + all_repos_by_id={repository["id"]: repository for repository in repositories}, + ) + + def make_user(self, user_id: str, username: str) -> shared_types.User: + return { + "id": user_id, + "username": username, + "builtinAuth": True, + "emails": [], + "externalAccounts": {"nodes": []}, + } + + def make_repo(self, repo_id: str, name: str) -> permission_types.Repository: + return {"id": repo_id, "name": name} + + +class QueryTests(unittest.TestCase): + def test_user_email_fields_are_opt_in(self) -> None: + self.assertNotIn("emails {", shared_queries.QUERY_USERS) + self.assertNotIn("emails {", shared_queries.query_users()) + self.assertIn("emails {", shared_queries.query_users(include_emails=True)) + + self.assertNotIn("emails {", permission_queries.QUERY_USER_BY_ID) + self.assertNotIn("emails {", permission_queries.query_user_by_id()) + self.assertIn("emails {", permission_queries.query_user_by_id(include_emails=True)) diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py index 3525fc9..bcb72f9 100644 --- a/tests/unit/test_snapshot.py +++ b/tests/unit/test_snapshot.py @@ -177,22 +177,14 @@ def test_list_users_explicit_repos_batches_aliases_and_follows_pages(self) -> No ), ] - class FakeGraphQLClient: - def __init__(self, **_kwargs: object) -> None: - pass - - def execute( - self, - query: str, - variables: src.JSONDict, - *, - follow_pages: bool = True, - ) -> src.JSONDict: - calls.append((query, dict(variables), follow_pages)) - return responses.pop(0) - - def graphql(query: str, variables: object = None) -> src.JSONDict: - return FakeGraphQLClient().execute(query, cast(src.JSONDict, variables or {})) + def graphql( + query: str, + variables: object = None, + *, + follow_pages: bool = True, + ) -> src.JSONDict: + calls.append((query, dict(cast(src.JSONDict, variables or {})), follow_pages)) + return responses.pop(0) client = cast( src.SourcegraphClient, @@ -203,12 +195,11 @@ def graphql(query: str, variables: object = None) -> src.JSONDict: graphql=graphql, ), ) - with patch.object(permissions_sourcegraph.src, "GraphQLClient", FakeGraphQLClient): - repos_by_user_id = permissions_sourcegraph.list_users_explicit_repos( - client, - ["user-1", "user-2"], - batch_size=2, - ) + repos_by_user_id = permissions_sourcegraph.list_users_explicit_repos( + client, + ["user-1", "user-2"], + batch_size=2, + ) self.assertEqual( { diff --git a/uv.lock b/uv.lock index e759784..12eb8bd 100644 --- a/uv.lock +++ b/uv.lock @@ -337,7 +337,7 @@ dev = [ requires-dist = [ { name = "json5", specifier = ">=0.14.0" }, { name = "pyyaml", specifier = ">=6.0.3" }, - { name = "src-py-lib", specifier = "==0.1.5" }, + { name = "src-py-lib", specifier = "==0.1.6" }, ] [package.metadata.requires-dev] @@ -349,16 +349,16 @@ dev = [ [[package]] name = "src-py-lib" -version = "0.1.5" +version = "0.1.6" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "httpx" }, { name = "pydantic" }, { name = "python-dotenv" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/68/39/dc534f18686d255141982cae8d7935c3cb807a6b98356d8936ed9c2d3b3d/src_py_lib-0.1.5.tar.gz", hash = "sha256:695f0fc0a2c539bd7ffc6c537822dca604fe8718de343c34f973765b31201d69", size = 71613, upload-time = "2026-05-29T08:58:25.545Z" } +sdist = { url = "https://files.pythonhosted.org/packages/3d/e7/bc59bf44fc2130df83aeef64dc2666f4617f236b61b879dc8d5629609361/src_py_lib-0.1.6.tar.gz", hash = "sha256:e2c5b015e2bb077e6116ad7457654cc81d17d13bc9f05768fa6720719d350f93", size = 71768, upload-time = "2026-05-29T15:19:45.891Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ff/ee/550cbda36b6853584f60f4acbbae781f2e0a38e12811fdcd8731532ed077/src_py_lib-0.1.5-py3-none-any.whl", hash = "sha256:1bafff027ccb68478d5712a5522e7e21dd4ef5fe51b14723fff95dbd6496db30", size = 44873, upload-time = "2026-05-29T08:58:24.475Z" }, + { url = "https://files.pythonhosted.org/packages/7e/eb/cddb0c92806cbb3ebfed4baa32ee6fa8550a2f5ac56c55de92e65f0066ba/src_py_lib-0.1.6-py3-none-any.whl", hash = "sha256:781c82fa42f48268a3b8b1ac7406fa69418dfd3d0ba3bc795b549093d004647a", size = 44956, upload-time = "2026-05-29T15:19:44.559Z" }, ] [[package]]