diff --git a/router-tests/events/nats_events_test.go b/router-tests/events/nats_events_test.go index 7b3f3822b5..5f242d3237 100644 --- a/router-tests/events/nats_events_test.go +++ b/router-tests/events/nats_events_test.go @@ -36,11 +36,11 @@ var ( type ConfigPollerMock struct { initConfig *nodev1.RouterConfig - updateConfig func(newConfig *nodev1.RouterConfig, oldVersion string) error + updateConfig func(newConfig *routerconfig.Response) error ready chan struct{} } -func (c *ConfigPollerMock) Subscribe(_ context.Context, handler func(newConfig *nodev1.RouterConfig, oldVersion string) error) { +func (c *ConfigPollerMock) Subscribe(_ context.Context, handler func(newConfig *routerconfig.Response) error) { c.updateConfig = handler close(c.ready) } @@ -1595,7 +1595,7 @@ func TestNatsEvents(t *testing.T) { xEnv.WaitForSubscriptionCount(3, EventWaitTimeout) // Swap config - require.NoError(t, pm.updateConfig(pm.initConfig, "old-1")) + require.NoError(t, pm.updateConfig(&routerconfig.Response{Config: pm.initConfig})) // Wait for all providers to shut down and restart require.Eventually(t, func() bool { diff --git a/router-tests/operations/cache_warmup_test.go b/router-tests/operations/cache_warmup_test.go index 2f9c662abd..dc8b0d144a 100644 --- a/router-tests/operations/cache_warmup_test.go +++ b/router-tests/operations/cache_warmup_test.go @@ -1016,7 +1016,7 @@ func TestInMemoryPlanCacheFallback(t *testing.T) { <-pm.ready pm.initConfig.Version = "updated" - require.NoError(t, pm.updateConfig(pm.initConfig, "old-1")) + require.NoError(t, pm.updateConfig(&routerconfig.Response{Config: pm.initConfig})) res = xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ Query: `{ employees { id } }`, @@ -1061,7 +1061,7 @@ func TestInMemoryPlanCacheFallback(t *testing.T) { <-pm.ready pm.initConfig.Version = "updated" - require.NoError(t, pm.updateConfig(pm.initConfig, "old-1")) + require.NoError(t, pm.updateConfig(&routerconfig.Response{Config: pm.initConfig})) res = 
xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ Query: `{ employees { id } }`, @@ -1110,7 +1110,7 @@ func TestInMemoryPlanCacheFallback(t *testing.T) { <-pm.ready pm.initConfig.Version = "updated" - require.NoError(t, pm.updateConfig(pm.initConfig, "old-1")) + require.NoError(t, pm.updateConfig(&routerconfig.Response{Config: pm.initConfig})) res = xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ Query: `{ employees { id customDetails: details { forename } } }`, @@ -1412,11 +1412,11 @@ func writeTestConfig(t *testing.T, version string, path string) { type ConfigPollerMock struct { initConfig *nodev1.RouterConfig - updateConfig func(newConfig *nodev1.RouterConfig, oldVersion string) error + updateConfig func(newConfig *routerconfig.Response) error ready chan struct{} } -func (c *ConfigPollerMock) Subscribe(_ context.Context, handler func(newConfig *nodev1.RouterConfig, oldVersion string) error) { +func (c *ConfigPollerMock) Subscribe(_ context.Context, handler func(_ *routerconfig.Response) error) { c.updateConfig = handler close(c.ready) } diff --git a/router-tests/operations/plan_fallback_cache_test.go b/router-tests/operations/plan_fallback_cache_test.go index 1b00ce0181..0ff63d461e 100644 --- a/router-tests/operations/plan_fallback_cache_test.go +++ b/router-tests/operations/plan_fallback_cache_test.go @@ -12,6 +12,7 @@ import ( nodev1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/node/v1" "github.com/wundergraph/cosmo/router/pkg/config" "github.com/wundergraph/cosmo/router/pkg/controlplane/configpoller" + "github.com/wundergraph/cosmo/router/pkg/routerconfig" ) func TestPlanFallbackCache(t *testing.T) { @@ -169,7 +170,7 @@ func TestPlanFallbackCache(t *testing.T) { // Trigger config reload — new Ristretto cache is created (size 1). 
<-pm.ready pm.initConfig.Version = "updated" - require.NoError(t, pm.updateConfig(pm.initConfig, "old-1")) + require.NoError(t, pm.updateConfig(&routerconfig.Response{Config: pm.initConfig})) // After reload, slow queries should still be available via fallback cache. waitForPlanCacheHits(t, xEnv, slowQueries, func(ct *assert.CollectT, res *testenv.TestResponse) { @@ -230,7 +231,7 @@ func TestPlanFallbackCache(t *testing.T) { // Trigger config reload — main plan cache is reset. <-pm.ready pm.initConfig.Version = "updated" - require.NoError(t, pm.updateConfig(pm.initConfig, "old-1")) + require.NoError(t, pm.updateConfig(&routerconfig.Response{Config: pm.initConfig})) // Wait for reload to complete by checking a slow query (which will be // served from the fallback cache, confirming the new server is active). @@ -293,7 +294,7 @@ func TestPlanFallbackCache(t *testing.T) { // First reload pm.initConfig.Version = "v2" - require.NoError(t, pm.updateConfig(pm.initConfig, "old-1")) + require.NoError(t, pm.updateConfig(&routerconfig.Response{Config: pm.initConfig})) waitForPlanCacheHits(t, xEnv, slowQueries, func(ct *assert.CollectT, res *testenv.TestResponse) { assert.Equal(ct, "v2", res.Response.Header.Get("X-Router-Config-Version")) @@ -301,7 +302,7 @@ func TestPlanFallbackCache(t *testing.T) { // Second reload pm.initConfig.Version = "v3" - require.NoError(t, pm.updateConfig(pm.initConfig, "v2")) + require.NoError(t, pm.updateConfig(&routerconfig.Response{Config: pm.initConfig})) waitForPlanCacheHits(t, xEnv, slowQueries, func(ct *assert.CollectT, res *testenv.TestResponse) { assert.Equal(ct, "v3", res.Response.Header.Get("X-Router-Config-Version")) diff --git a/router-tests/protocol/config_hot_reload_test.go b/router-tests/protocol/config_hot_reload_test.go index 447fff851e..6221ebd5df 100644 --- a/router-tests/protocol/config_hot_reload_test.go +++ b/router-tests/protocol/config_hot_reload_test.go @@ -5,15 +5,16 @@ import ( "context" "encoding/json" + "os" + 
"sync/atomic" + "testing" + "time" + "github.com/wundergraph/cosmo/router/pkg/otel" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/sdk/metric" "go.opentelemetry.io/otel/sdk/metric/metricdata" "go.opentelemetry.io/otel/sdk/metric/metricdata/metricdatatest" - "os" - "sync/atomic" - "testing" - "time" "github.com/wundergraph/cosmo/router/pkg/routerconfig" @@ -33,11 +34,11 @@ var ( type ConfigPollerMock struct { initConfig *nodev1.RouterConfig - updateConfig func(newConfig *nodev1.RouterConfig, oldVersion string) error + updateConfig func(response *routerconfig.Response) error ready chan struct{} } -func (c *ConfigPollerMock) Subscribe(_ context.Context, handler func(newConfig *nodev1.RouterConfig, oldVersion string) error) { +func (c *ConfigPollerMock) Subscribe(_ context.Context, handler func(response *routerconfig.Response) error) { c.updateConfig = handler close(c.ready) } @@ -86,7 +87,7 @@ func TestConfigHotReloadPoller(t *testing.T) { <-pm.ready pm.initConfig.Version = "updated" - require.NoError(t, pm.updateConfig(pm.initConfig, "old-1")) + require.NoError(t, pm.updateConfig(&routerconfig.Response{Config: pm.initConfig})) res = xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ Query: `{ employees { id } }`, @@ -156,7 +157,7 @@ func TestConfigHotReloadPoller(t *testing.T) { // Swap config pm.initConfig.Version = "updated" - require.NoError(t, pm.updateConfig(pm.initConfig, "old-1")) + require.NoError(t, pm.updateConfig(&routerconfig.Response{Config: pm.initConfig})) res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ Query: `{ employees { id } }`, @@ -223,7 +224,7 @@ func TestConfigHotReloadPoller(t *testing.T) { // Swap config — the ReadJSON below expects a possible websocket close error, // so use a deadline instead of WSReadJSON (which retries on errors) - require.NoError(t, pm.updateConfig(pm.initConfig, "old-1")) + require.NoError(t, pm.updateConfig(&routerconfig.Response{Config: pm.initConfig})) conn.SetReadDeadline(time.Now().Add(5 * 
time.Second)) err = conn.ReadJSON(&msg) conn.SetReadDeadline(time.Time{}) @@ -657,7 +658,7 @@ func BenchmarkConfigHotReload(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - require.NoError(t, pm.updateConfig(pm.initConfig, "old-1")) + require.NoError(t, pm.updateConfig(&routerconfig.Response{Config: pm.initConfig})) } }) diff --git a/router/core/graph_server.go b/router/core/graph_server.go index 23812b8fe4..6fc67b68ca 100644 --- a/router/core/graph_server.go +++ b/router/core/graph_server.go @@ -22,6 +22,7 @@ import ( "github.com/golang-jwt/jwt/v5" "github.com/klauspost/compress/gzhttp" "github.com/klauspost/compress/gzip" + "github.com/wundergraph/cosmo/router/pkg/routerconfig" "go.opentelemetry.io/otel/attribute" otelmetric "go.opentelemetry.io/otel/metric" oteltrace "go.opentelemetry.io/otel/trace" @@ -91,9 +92,11 @@ type ( baseRouterConfigVersion string mux *chi.Mux // inFlightRequests is used to track the number of requests currently being processed - // does not include websocket (hijacked) connections - inFlightRequests *atomic.Uint64 - graphMuxList []*graphMux + // does not include websocket (hijacked) connections. + inFlightRequests *atomic.Uint64 + // graphMuxList contains all graph muxes of this graph server. + // It's keyed by mux name (feature flag name or empty string for base graph). + graphMuxList map[string]*graphMux graphMuxListLock sync.Mutex runtimeMetrics *rmetric.RuntimeMetrics otlpEngineMetrics *rmetric.EngineMetrics @@ -127,18 +130,20 @@ type buildMultiGraphHandlerOptions struct { baseMux *chi.Mux featureFlagConfigs map[string]*nodev1.FeatureFlagRouterExecutionConfig reloadPersistentState *ReloadPersistentState + currentGraphMuxes map[string]*graphMux + changes *routerconfig.Changes } // newGraphServer creates a new server instance. 
-func newGraphServer(ctx context.Context, r *Router, routerConfig *nodev1.RouterConfig, proxy ProxyFunc) (*graphServer, error) { +func newGraphServer(ctx context.Context, r *Router, response *routerconfig.Response, proxy ProxyFunc) (*graphServer, error) { /* Older versions of composition will not populate a compatibility version. * Currently, all "old" router execution configurations are compatible as there have been no breaking * changes. * Upon the first breaking change to the execution config, an unpopulated compatibility version will * also be unsupported (and the logic for IsRouterCompatibleWithExecutionConfig will need to be updated). */ - if !execution_config.IsRouterCompatibleWithExecutionConfig(r.logger, routerConfig.CompatibilityVersion) { - return nil, fmt.Errorf(`the compatibility version "%s" is not compatible with this router version`, routerConfig.CompatibilityVersion) + if !execution_config.IsRouterCompatibleWithExecutionConfig(r.logger, response.Config.CompatibilityVersion) { + return nil, fmt.Errorf(`the compatibility version "%s" is not compatible with this router version`, response.Config.CompatibilityVersion) } isConnStoreEnabled := r.metricConfig.OpenTelemetry.ConnectionStats || r.metricConfig.Prometheus.ConnectionStats @@ -186,9 +191,9 @@ func newGraphServer(ctx context.Context, r *Router, routerConfig *nodev1.RouterC subgraphTransports: subgraphTransports, playgroundHandler: r.playgroundHandler, traceDialer: traceDialer, - baseRouterConfigVersion: routerConfig.GetVersion(), + baseRouterConfigVersion: response.Config.GetVersion(), inFlightRequests: &atomic.Uint64{}, - graphMuxList: make([]*graphMux, 0, 1), + graphMuxList: make(map[string]*graphMux, 1), instanceData: InstanceData{ HostName: r.hostName, ListenAddress: r.listenAddr, @@ -294,23 +299,38 @@ func newGraphServer(ctx context.Context, r *Router, routerConfig *nodev1.RouterC s.circuitBreakerManager = manager } - routingUrlGroupings, err := 
getRoutingUrlGroupingForCircuitBreakers(routerConfig, s.overrideRoutingURLConfiguration, s.overrides) + routingUrlGroupings, err := getRoutingUrlGroupingForCircuitBreakers(response.Config, s.overrideRoutingURLConfiguration, s.overrides) if err != nil { return nil, err } - gm, err := s.buildGraphMux(ctx, BuildGraphMuxOptions{ - RouterConfigVersion: s.baseRouterConfigVersion, - EngineConfig: routerConfig.GetEngineConfig(), - ConfigSubgraphs: routerConfig.GetSubgraphs(), - RoutingUrlGroupings: routingUrlGroupings, - ReloadPersistentState: r.reloadPersistentState, - }) - if err != nil { - return nil, fmt.Errorf("failed to build base mux: %w", err) + currentMuxes := currentGraphMuxes(r) + var gm *graphMux + + mux, oldBaseGraphMuxExists := currentMuxes[""] + needNewBaseGraphMux := response.Changes == nil || response.Changes.BaseGraphChanged() || !oldBaseGraphMuxExists + + if needNewBaseGraphMux { + // build a new base graph mux + gm, err = s.buildGraphMux(ctx, BuildGraphMuxOptions{ + RouterConfigVersion: s.baseRouterConfigVersion, + EngineConfig: response.Config.GetEngineConfig(), + ConfigSubgraphs: response.Config.GetSubgraphs(), + RoutingUrlGroupings: routingUrlGroupings, + ReloadPersistentState: r.reloadPersistentState, + }) + if err != nil { + return nil, fmt.Errorf("failed to build base mux: %w", err) + } + } else { + gm = mux + gm.reused.Store(true) + s.graphMuxListLock.Lock() + s.graphMuxList[""] = gm + s.graphMuxListLock.Unlock() } - featureFlagConfigMap := routerConfig.FeatureFlagConfigs.GetConfigByFeatureFlagName() + featureFlagConfigMap := response.Config.FeatureFlagConfigs.GetConfigByFeatureFlagName() if len(featureFlagConfigMap) > 0 { s.logger.Info("Feature flags enabled", zap.Strings("flags", maps.Keys(featureFlagConfigMap))) } @@ -319,6 +339,8 @@ func newGraphServer(ctx context.Context, r *Router, routerConfig *nodev1.RouterC baseMux: gm.mux, featureFlagConfigs: featureFlagConfigMap, reloadPersistentState: r.reloadPersistentState, + currentGraphMuxes: 
currentMuxes, + changes: response.Changes, + }) if err != nil { return nil, fmt.Errorf("failed to build feature flag handler: %w", err) } @@ -477,6 +499,24 @@ func (s *graphServer) buildMultiGraphHandler( // Build all the muxes for the feature flags in serial to avoid any race conditions for featureFlagName, executionConfig := range opts.featureFlagConfigs { + if opts.changes != nil { + // if the feature flag is unchanged and still present, we reuse its existing mux + _, hasChanged := opts.changes.ChangedConfigs[featureFlagName] + _, wasAdded := opts.changes.AddedConfigs[featureFlagName] + + if !hasChanged && !wasAdded { + oldGraphMux, exists := opts.currentGraphMuxes[featureFlagName] + if exists { + featureFlagToMux[featureFlagName] = oldGraphMux.mux + s.graphMuxListLock.Lock() + s.graphMuxList[featureFlagName] = oldGraphMux + s.graphMuxListLock.Unlock() + oldGraphMux.reused.Store(true) + continue + } + } + } + gm, err := s.buildGraphMux(ctx, BuildGraphMuxOptions{ FeatureFlagName: featureFlagName, RouterConfigVersion: executionConfig.GetVersion(), @@ -544,7 +584,8 @@ func (s *graphServer) setupEngineStatistics(baseAttributes []attribute.KeyValue) } type graphMux struct { - mux *chi.Mux + mux *chi.Mux + reused atomic.Bool planCache *ristretto.Cache[uint64, *planWithMetaData] planFallbackCache *slowplancache.Cache[*planWithMetaData] @@ -1734,7 +1775,7 @@ func (s *graphServer) buildGraphMux( s.graphMuxListLock.Lock() defer s.graphMuxListLock.Unlock() - s.graphMuxList = append(s.graphMuxList, gm) + s.graphMuxList[opts.FeatureFlagName] = gm return gm, nil } @@ -1982,11 +2023,15 @@ func (s *graphServer) Shutdown(ctx context.Context) error { } } - // Shutdown all graphs muxes to release resources + // Shutdown graph muxes that are not reused by the next graph server to release resources // e.g.
planner cache s.graphMuxListLock.Lock() defer s.graphMuxListLock.Unlock() for _, mux := range s.graphMuxList { + if mux.reused.Load() { + mux.reused.Store(false) // set to false to prevent the mux from being skipped forever + continue + } if err := mux.Shutdown(ctx); err != nil { finalErr = errors.Join(finalErr, err) } @@ -2171,3 +2216,22 @@ configureSubgraphOverwrites( return subgraphs, nil } + +// currentGraphMuxes returns a map of currently active graph muxes +// used by the currently running graph server. +func currentGraphMuxes(r *Router) map[string]*graphMux { + currentState := r.httpServer.state.Load() + if currentState == nil { + return nil + } + + currentGraphServer := currentState.graphServer + if currentGraphServer == nil { + return nil + } + + currentGraphServer.graphMuxListLock.Lock() + defer currentGraphServer.graphMuxListLock.Unlock() + + return maps.Clone(currentGraphServer.graphMuxList) +} diff --git a/router/core/init_config_poller.go b/router/core/init_config_poller.go index 2748be71f5..38de03c448 100644 --- a/router/core/init_config_poller.go +++ b/router/core/init_config_poller.go @@ -136,6 +136,7 @@ func InitializeConfigPoller(r *Router, registry *ProviderRegistry) (*configpolle if hasSplitCfgFeature { providerID := r.routerConfigPollerConfig.Storage.ProviderID if providerID == "" { + r.logger.Debug("Use split-config poller to fetch execution config") return newSplitConfigPoller(r) } r.logger.Info("split-config-loading feature is enabled but a custom storage provider is configured; falling back to regular config polling", diff --git a/router/core/router.go b/router/core/router.go index cb173417d3..4dc4a5a9fe 100644 --- a/router/core/router.go +++ b/router/core/router.go @@ -17,6 +17,7 @@ import ( "connectrpc.com/connect" "github.com/mitchellh/mapstructure" "github.com/nats-io/nuid" + "github.com/wundergraph/cosmo/router/pkg/routerconfig" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/propagation" 
@@ -602,8 +603,8 @@ func NewRouter(opts ...Option) (*Router, error) { } // newGraphServer creates a new server. -func (r *Router) newServer(ctx context.Context, cfg *nodev1.RouterConfig) error { - server, err := newGraphServer(ctx, r, cfg, r.proxy) +func (r *Router) newServer(ctx context.Context, response *routerconfig.Response) error { + server, err := newGraphServer(ctx, r, response, r.proxy) if err != nil { r.logger.Error("Failed to create graph server. Keeping the old server", zap.Error(err)) return err @@ -612,7 +613,7 @@ func (r *Router) newServer(ctx context.Context, cfg *nodev1.RouterConfig) error r.httpServer.SwapGraphServer(ctx, server) // Cleanup any unused feature flags in case a feature flag was removed - r.reloadPersistentState.CleanupFeatureFlags(cfg) + r.reloadPersistentState.CleanupFeatureFlags(response.Config) return nil } @@ -792,7 +793,7 @@ func (r *Router) NewServer(ctx context.Context) (Server, error) { // Start the server with the static config without polling if r.staticExecutionConfig != nil { r.logger.Info("Static execution config provided. Polling is disabled. 
Updating execution config is only possible by providing a config.") - return r.httpServer, r.newServer(ctx, r.staticExecutionConfig) + return r.httpServer, r.newServer(ctx, &routerconfig.Response{Config: r.staticExecutionConfig}) } // when no static config is provided and no poller is configured, we can't start the server @@ -805,7 +806,7 @@ func (r *Router) NewServer(ctx context.Context) (Server, error) { return nil, fmt.Errorf("failed to get initial execution config: %w", err) } - if err := r.newServer(ctx, cfg.Config); err != nil { + if err := r.newServer(ctx, cfg); err != nil { r.logger.Error("Failed to start server with initial config", zap.Error(err)) return nil, err } @@ -1463,7 +1464,7 @@ func (r *Router) Start(ctx context.Context) error { return err } - if err := r.newServer(ctx, r.staticExecutionConfig); err != nil { + if err := r.newServer(ctx, &routerconfig.Response{Config: r.staticExecutionConfig}); err != nil { return err } @@ -1507,7 +1508,7 @@ func (r *Router) Start(ctx context.Context) error { return } - if err := r.newServer(ctx, cfg); err != nil { + if err := r.newServer(ctx, &routerconfig.Response{Config: cfg}); err != nil { ll.Error("Failed to update server with new config", zap.Error(err)) return } @@ -1557,7 +1558,7 @@ func (r *Router) Start(ctx context.Context) error { return err } - if err := r.newServer(ctx, cfg.Config); err != nil { + if err := r.newServer(ctx, cfg); err != nil { return err } @@ -1591,15 +1592,15 @@ func (r *Router) Start(ctx context.Context) error { ) } - r.configPoller.Subscribe(ctx, func(newConfig *nodev1.RouterConfig, oldVersion string) error { + r.configPoller.Subscribe(ctx, func(response *routerconfig.Response) error { if r.shutdown.Load() { r.logger.Warn("Router is in shutdown state. 
Skipping config update") return nil } - r.trackExecutionConfigUsage(newConfig, false) + r.trackExecutionConfigUsage(response.Config, false) - if err := r.newServer(ctx, newConfig); err != nil { + if err := r.newServer(ctx, response); err != nil { return err } diff --git a/router/pkg/controlplane/configpoller/config_poller.go b/router/pkg/controlplane/configpoller/config_poller.go index 0cb9b20f69..0c34bd2aa2 100644 --- a/router/pkg/controlplane/configpoller/config_poller.go +++ b/router/pkg/controlplane/configpoller/config_poller.go @@ -7,7 +7,6 @@ import ( "github.com/wundergraph/cosmo/router/pkg/routerconfig" - nodev1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/node/v1" "github.com/wundergraph/cosmo/router/pkg/controlplane" "go.uber.org/zap" ) @@ -19,9 +18,9 @@ var ErrConfigNotFound = errors.New("config not found") type ConfigPoller interface { // Subscribe subscribes to the config poller with a handler function that will be invoked - // with the latest router config and the previous version string. If the handler takes longer than the poll interval + // with the latest router config. If the handler takes longer than the poll interval // to execute, the next invocation will be skipped. - Subscribe(ctx context.Context, handler func(newConfig *nodev1.RouterConfig, oldVersion string) error) + Subscribe(ctx context.Context, handler func(response *routerconfig.Response) error) // GetRouterConfig returns the latest router config from the CDN // If the Config is nil, no new config is available and the current config should be used. // and updates the latest router config version. 
This method is only used for the initial config @@ -63,7 +62,7 @@ func (c *configPoller) Version() string { return c.latestRouterConfigVersion } -func (c *configPoller) Subscribe(ctx context.Context, handler func(newConfig *nodev1.RouterConfig, _ string) error) { +func (c *configPoller) Subscribe(ctx context.Context, handler func(newConfig *routerconfig.Response) error) { c.poller.Subscribe(ctx, func() { start := time.Now() @@ -99,7 +98,12 @@ func (c *configPoller) Subscribe(ctx context.Context, handler func(newConfig *no start = time.Now() - if err := handler(cfg.Config, c.latestRouterConfigVersion); err != nil { + response := &routerconfig.Response{ + Config: cfg.Config, + Changes: nil, // purposefully leaving this nil to indicate we don't know what changed + } + + if err := handler(response); err != nil { c.logger.Error("Error invoking config poll handler", zap.Error(err)) return } diff --git a/router/pkg/controlplane/configpoller/split_config_poller.go b/router/pkg/controlplane/configpoller/split_config_poller.go index ed661e4efe..2cd3a72c94 100644 --- a/router/pkg/controlplane/configpoller/split_config_poller.go +++ b/router/pkg/controlplane/configpoller/split_config_poller.go @@ -3,6 +3,7 @@ package configpoller import ( "context" "fmt" + "maps" "slices" "time" @@ -146,11 +147,16 @@ func (p *splitConfigPoller) GetRouterConfig(ctx context.Context) (*routerconfig. p.currentConfig = config p.latestVersion = computeCompositeVersion(graphConfigs) - return &routerconfig.Response{Config: config}, nil + response := &routerconfig.Response{ + Config: config, + Changes: nil, // purposefully nil to tell callers to rebuild everything since this is the initial fetch + } + + return response, nil } // Subscribe starts the polling loop and calls handler whenever the assembled config changes. 
-func (p *splitConfigPoller) Subscribe(ctx context.Context, handler func(newConfig *nodev1.RouterConfig, oldVersion string) error) { +func (p *splitConfigPoller) Subscribe(ctx context.Context, handler func(response *routerconfig.Response) error) { p.poller.Subscribe(ctx, func() { fetchStart := time.Now() @@ -184,20 +190,22 @@ func (p *splitConfigPoller) Subscribe(ctx context.Context, handler func(newConfi ) // Determine what changed, was added, or was removed. - changed := make(map[string]struct{}) - added := make(map[string]struct{}) - removed := make(map[string]struct{}) + changes := routerconfig.Changes{ + AddedConfigs: make(map[string]struct{}), + RemovedConfigs: make(map[string]struct{}), + ChangedConfigs: make(map[string]struct{}), + } for name, hash := range graphConfigs { if oldHash, exists := p.knownHashes[name]; !exists { - added[name] = struct{}{} + changes.AddedConfigs[name] = struct{}{} } else if oldHash != hash { - changed[name] = struct{}{} + changes.ChangedConfigs[name] = struct{}{} } } for name := range p.knownHashes { if _, exists := graphConfigs[name]; !exists { - removed[name] = struct{}{} + changes.RemovedConfigs[name] = struct{}{} } } @@ -205,13 +213,9 @@ func (p *splitConfigPoller) Subscribe(ctx context.Context, handler func(newConfi patched := proto.Clone(p.currentConfig).(*nodev1.RouterConfig) // Apply changes and additions. - toFetch := make(map[string]struct{}, len(changed)+len(added)) - for name := range changed { - toFetch[name] = struct{}{} - } - for name := range added { - toFetch[name] = struct{}{} - } + toFetch := make(map[string]struct{}, len(changes.ChangedConfigs)+len(changes.AddedConfigs)) + maps.Copy(toFetch, changes.ChangedConfigs) + maps.Copy(toFetch, changes.AddedConfigs) for name := range toFetch { fetchedConfig, err := p.fetcher.FetchConfig(ctx, name) @@ -244,7 +248,7 @@ func (p *splitConfigPoller) Subscribe(ctx context.Context, handler func(newConfi } // Remove deleted feature flags. 
- for name := range removed { + for name := range changes.RemovedConfigs { if name == "" { continue // base graph cannot be removed } @@ -256,10 +260,13 @@ func (p *splitConfigPoller) Subscribe(ctx context.Context, handler func(newConfi } } - oldVersion := p.latestVersion + response := &routerconfig.Response{ + Config: patched, + Changes: &changes, + } handlerStart := time.Now() - if err := handler(patched, oldVersion); err != nil { + if err := handler(response); err != nil { p.logger.Error("Error invoking config poll handler", zap.Error(err)) return } @@ -269,7 +276,8 @@ func (p *splitConfigPoller) Subscribe(ctx context.Context, handler func(newConfi zap.String("config_version", newVersion), ) - // Only update internal state after the handler succeeds. + // Only update internal state after the handler succeeds, + // i.e. the newly created engine config is actually used by the graph server. p.knownHashes = graphConfigs p.currentConfig = patched p.latestVersion = newVersion diff --git a/router/pkg/controlplane/configpoller/split_config_poller_test.go b/router/pkg/controlplane/configpoller/split_config_poller_test.go index 6760ce89c9..620f09fb87 100644 --- a/router/pkg/controlplane/configpoller/split_config_poller_test.go +++ b/router/pkg/controlplane/configpoller/split_config_poller_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" nodev1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/node/v1" + "github.com/wundergraph/cosmo/router/pkg/routerconfig" "go.uber.org/zap" ) @@ -150,7 +151,7 @@ func TestSplitGetRouterConfig_ConfigFetchError(t *testing.T) { // pollOnce manually executes one poll iteration using the poller's internal logic. // It extracts the subscribe callback by using a fake controlplane.Poller. 
-func pollOnce(p *splitConfigPoller, handler func(*nodev1.RouterConfig, string) error) { +func pollOnce(p *splitConfigPoller, handler func(_ *routerconfig.Response) error) { var tickFn func() p.poller = &capturingPoller{capture: &tickFn} p.Subscribe(context.Background(), handler) // sets tickFn @@ -182,7 +183,7 @@ func TestSplitSubscribe_NoChanges(t *testing.T) { p.latestVersion = computeCompositeVersion(p.knownHashes) handlerCalled := false - pollOnce(p, func(_ *nodev1.RouterConfig, _ string) error { + pollOnce(p, func(_ *routerconfig.Response) error { handlerCalled = true return nil }) @@ -207,8 +208,8 @@ func TestSplitSubscribe_BaseGraphChanged(t *testing.T) { p.latestVersion = computeCompositeVersion(p.knownHashes) var received *nodev1.RouterConfig - pollOnce(p, func(cfg *nodev1.RouterConfig, _ string) error { - received = cfg + pollOnce(p, func(resp *routerconfig.Response) error { + received = resp.Config return nil }) @@ -246,8 +247,8 @@ func TestSplitSubscribe_SingleFFChanged(t *testing.T) { p.latestVersion = computeCompositeVersion(p.knownHashes) var received *nodev1.RouterConfig - pollOnce(p, func(cfg *nodev1.RouterConfig, _ string) error { - received = cfg + pollOnce(p, func(resp *routerconfig.Response) error { + received = resp.Config return nil }) @@ -278,8 +279,8 @@ func TestSplitSubscribe_FFAdded(t *testing.T) { p.latestVersion = computeCompositeVersion(p.knownHashes) var received *nodev1.RouterConfig - pollOnce(p, func(cfg *nodev1.RouterConfig, _ string) error { - received = cfg + pollOnce(p, func(resp *routerconfig.Response) error { + received = resp.Config return nil }) @@ -311,8 +312,8 @@ func TestSplitSubscribe_FFRemoved(t *testing.T) { p.latestVersion = computeCompositeVersion(p.knownHashes) var received *nodev1.RouterConfig - pollOnce(p, func(cfg *nodev1.RouterConfig, _ string) error { - received = cfg + pollOnce(p, func(resp *routerconfig.Response) error { + received = resp.Config return nil }) @@ -353,8 +354,8 @@ func 
TestSplitSubscribe_MultipleChanges(t *testing.T) { p.latestVersion = computeCompositeVersion(p.knownHashes) var received *nodev1.RouterConfig - pollOnce(p, func(cfg *nodev1.RouterConfig, _ string) error { - received = cfg + pollOnce(p, func(resp *routerconfig.Response) error { + received = resp.Config return nil }) @@ -379,7 +380,7 @@ func TestSplitSubscribe_MapperFetchFailure(t *testing.T) { p.latestVersion = initialVersion handlerCalled := false - pollOnce(p, func(_ *nodev1.RouterConfig, _ string) error { + pollOnce(p, func(_ *routerconfig.Response) error { handlerCalled = true return nil }) @@ -403,7 +404,7 @@ func TestSplitSubscribe_ConfigFetchFailure(t *testing.T) { p.latestVersion = initialVersion handlerCalled := false - pollOnce(p, func(_ *nodev1.RouterConfig, _ string) error { + pollOnce(p, func(_ *routerconfig.Response) error { handlerCalled = true return nil }) @@ -428,7 +429,7 @@ func TestSplitSubscribe_HandlerError_StateNotUpdated(t *testing.T) { initialVersion := computeCompositeVersion(p.knownHashes) p.latestVersion = initialVersion - pollOnce(p, func(_ *nodev1.RouterConfig, _ string) error { + pollOnce(p, func(_ *routerconfig.Response) error { return errors.New("handler failed") }) diff --git a/router/pkg/routerconfig/client.go b/router/pkg/routerconfig/client.go index 8537d72006..2d1e50c8f5 100644 --- a/router/pkg/routerconfig/client.go +++ b/router/pkg/routerconfig/client.go @@ -11,6 +11,21 @@ import ( type Response struct { // Config is the marshaled router config Config *nodev1.RouterConfig + // Changes is a summary of which parts of Config + // have changed since the last successful config apply. + // Nil means changes are unknown -> expect everything to be changed. + Changes *Changes +} + +type Changes struct { + AddedConfigs map[string]struct{} + RemovedConfigs map[string]struct{} + ChangedConfigs map[string]struct{} +} + +func (c *Changes) BaseGraphChanged() bool { + _, exists := c.ChangedConfigs[""] + return exists } type Client interface {