From 67748946ae321221c71cc3887aaacd13df42692c Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Wed, 24 Jun 2026 12:52:45 +0800 Subject: [PATCH 01/16] update tinker client sdk --- .../server/transformer/server_config.yaml | 70 ++++++++-------- .../tinker/modelscope/short_math_grpo.py | 2 +- .../tinker/self_host/short_math_grpo.py | 2 +- src/twinkle/server/model/tinker_handlers.py | 14 ++-- .../integration/test_mock_mode_startup.py | 11 +++ tests/server/model/test_tinker_handlers.py | 79 +++++++++++++++++++ 6 files changed, 136 insertions(+), 42 deletions(-) diff --git a/cookbook/client/server/transformer/server_config.yaml b/cookbook/client/server/transformer/server_config.yaml index 183bfff24..b5d8497fd 100644 --- a/cookbook/client/server/transformer/server_config.yaml +++ b/cookbook/client/server/transformer/server_config.yaml @@ -11,7 +11,7 @@ http_options: # Telemetry: push traces/metrics/logs to LGTM's OTel Collector via OTLP telemetry: - enabled: true + enabled: false otlp_endpoint: http://localhost:4317 # Persistence configuration for ServerState (sessions, models, futures, ...). @@ -81,43 +81,43 @@ applications: num_cpus: 0.1 runtime_env: env_vars: - TWINKLE_TRUST_REMOTE_CODE: "0" + TWINKLE_TRUST_REMOTE_CODE: "1" # 3. Sampler Service - Runs inference / sampling using vLLM engine # Used for generating text from the model (e.g., evaluating LoRA results). - # - name: sampler-Qwen3.5-4B - # route_prefix: /api/v1/sampler/Qwen/Qwen3.5-4B - # import_path: sampler - # args: - # model_id: "ms://Qwen/Qwen3.5-4B" # ModelScope model identifier - # nproc_per_node: 2 # Number of GPU processes per node - # sampler_type: vllm # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler) - # engine_args: # vLLM engine-specific settings - # max_model_len: 4096 # Maximum sequence length the engine supports - # gpu_memory_utilization: 0.5 # Fraction of GPU memory to use (0.0-1.0) - # enable_lora: true # Allow loading LoRA adapters during inference - # logprobs_mode: processed_logprobs # Logprobs mode for sampling results - # device_group: # Logical device group for the sampler - # name: sampler - # ranks: 1 # Number of GPUs to use - # device_type: cuda - # device_mesh: - # device_type: cuda - # dp_size: 1 - # queue_config: - # rps_limit: 100 # Max requests per second - # tps_limit: 100000 # Max tokens per second - # deployments: - # - name: SamplerManagement - # autoscaling_config: - # min_replicas: 1 - # max_replicas: 1 - # target_ongoing_requests: 16 - # ray_actor_options: - # num_cpus: 0.1 - # runtime_env: - # env_vars: - # TWINKLE_TRUST_REMOTE_CODE: "0" + - name: sampler-Qwen3.5-4B + route_prefix: /api/v1/sampler/Qwen/Qwen3.5-4B + import_path: sampler + args: + model_id: "ms://Qwen/Qwen3.5-4B" # ModelScope model identifier + nproc_per_node: 1 # Number of GPU processes per node + sampler_type: vllm # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler) + engine_args: # vLLM engine-specific settings + max_model_len: 4096 # Maximum sequence length the engine supports + gpu_memory_utilization: 0.5 # Fraction of GPU memory to use (0.0-1.0) + enable_lora: true # Allow loading LoRA adapters during inference + logprobs_mode: processed_logprobs # Logprobs mode for sampling results + device_group: # Logical device group for the sampler + name: sampler + ranks: 1 # Number of GPUs to use + device_type: cuda + device_mesh: + device_type: cuda + dp_size: 1 + queue_config: + rps_limit: 100 # Max requests per second + tps_limit: 100000 # Max tokens per second + deployments: + - name: SamplerManagement + autoscaling_config: + min_replicas: 1 + max_replicas: 1 + target_ongoing_requests: 16 + ray_actor_options: + num_cpus: 0.1 + runtime_env: + env_vars: + TWINKLE_TRUST_REMOTE_CODE: "1" # 4. Processor Service - name: processor diff --git a/cookbook/client/tinker/modelscope/short_math_grpo.py b/cookbook/client/tinker/modelscope/short_math_grpo.py index bf57a9424..780df4448 100644 --- a/cookbook/client/tinker/modelscope/short_math_grpo.py +++ b/cookbook/client/tinker/modelscope/short_math_grpo.py @@ -165,7 +165,7 @@ def main(): if step % SYNC_INTERVAL == 0: logger.info(f'Step {step}: Saving weights for sampler...') - sampling_client = (training_client.save_weights_and_get_sampling_client(name=f'GSM8K-step-{step}')) + sampling_client = training_client.save_weights_and_get_sampling_client() logger.info(f'Step {step}: Sampling client ready') if sampling_client is None: diff --git a/cookbook/client/tinker/self_host/short_math_grpo.py b/cookbook/client/tinker/self_host/short_math_grpo.py index e19955a6c..f87fe812e 100644 --- a/cookbook/client/tinker/self_host/short_math_grpo.py +++ b/cookbook/client/tinker/self_host/short_math_grpo.py @@ -165,7 +165,7 @@ def main(): if step % SYNC_INTERVAL == 0: logger.info(f'Step {step}: Saving weights for sampler...') - sampling_client = (training_client.save_weights_and_get_sampling_client(name=f'GSM8K-step-{step}')) + sampling_client = training_client.save_weights_and_get_sampling_client() logger.info(f'Step {step}: Sampling client ready') if sampling_client is None: diff --git a/src/twinkle/server/model/tinker_handlers.py b/src/twinkle/server/model/tinker_handlers.py index b6973db04..1ddc87489 100644 --- a/src/twinkle/server/model/tinker_handlers.py +++ b/src/twinkle/server/model/tinker_handlers.py @@ -272,11 +272,15 @@ async def _do_save_for_sampler(): if metadata.get('base_model'): payload['base_model'] = metadata['base_model'] sampling_session_id = await self.state.create_sampling_session(payload) - # Return ``tinker_path`` (not None): tinker SDK's - # ``_save_weights_for_sampler_async`` asserts ``result.path is not None``. - # ``sampling_session_id`` is still the canonical handle. - return types.SaveWeightsForSamplerResponseInternal( - path=tinker_path, sampling_session_id=sampling_session_id) + # Tinker SDK distinguishes two modes by whether sampling_session_seq_id is set: + # 1. save_weights_for_sampler(name): expects path != None + # 2. save_weights_and_get_sampling_client(): expects path == None, sampling_session_id != None + if body.sampling_session_seq_id is not None: + return types.SaveWeightsForSamplerResponseInternal( + path=None, sampling_session_id=sampling_session_id) + else: + return types.SaveWeightsForSamplerResponseInternal( + path=tinker_path, sampling_session_id=sampling_session_id) except Exception: logger.error(traceback.format_exc()) return types.RequestFailedResponse( diff --git a/tests/server/integration/test_mock_mode_startup.py b/tests/server/integration/test_mock_mode_startup.py index 30eaf2df6..0e7936437 100644 --- a/tests/server/integration/test_mock_mode_startup.py +++ b/tests/server/integration/test_mock_mode_startup.py @@ -339,6 +339,8 @@ def _exercise_tinker_client(base: str) -> None: ).result() sampler_ckpt = training.save_weights_for_sampler(name='step-1').result() + # path mode: save_weights_for_sampler(name) must return path != None + assert sampler_ckpt.path is not None # Gateway's /asample resolves ``base_model`` from ``body.base_model`` or # ``sampling_session_id``; pass it explicitly because the SDK only sets # ``model_path`` and the gateway doesn't parse ``twinkle://`` URIs. @@ -349,6 +351,15 @@ def _exercise_tinker_client(base: str) -> None: sampling_params=types.SamplingParams(max_tokens=4), ).result() + # sampling_session_seq_id mode: save_weights_and_get_sampling_client() + # must return path == None and sampling_session_id != None (asserted by SDK internally) + sampling_client = training.save_weights_and_get_sampling_client() + sampling_client.sample( + prompt=types.ModelInput.from_ints([1, 2, 3]), + num_samples=1, + sampling_params=types.SamplingParams(max_tokens=4), + ).result() + # ``TrainingClient`` exposes save_state/load_state, not save_weights — # the wire-level handler is /tinker/save_weights either way. ckpt = training.save_state(name='step-2').result() diff --git a/tests/server/model/test_tinker_handlers.py b/tests/server/model/test_tinker_handlers.py index 72efbaf90..a5330172f 100644 --- a/tests/server/model/test_tinker_handlers.py +++ b/tests/server/model/test_tinker_handlers.py @@ -1,4 +1,5 @@ import pytest +from unittest.mock import AsyncMock, MagicMock, patch from fastapi import FastAPI from starlette.requests import Request from tinker import types @@ -46,3 +47,81 @@ async def test_tinker_dpo_forward_backward_requires_per_dp_pairs(): assert management.scheduled[-1]['batch_size'] == 2 assert management.scheduled[-1]['data_world_size'] == 2 assert management.scheduled[-1]['batch_size_multiple'] == 2 + + +class _SaveWeightsDummyManagement: + """Dummy management that actually executes the task to test save_weights_for_sampler logic.""" + + def __init__(self): + self.model = MagicMock() + self.state = MagicMock() + self.state.get_model_metadata = AsyncMock(return_value={'base_model': 'test-model'}) + self.state.create_sampling_session = AsyncMock(return_value='session-123') + + async def _on_request_start(self, request): + return 'token1' + + def get_adapter_name(self, adapter_name=None): + return adapter_name + + def assert_resource_exists(self, adapter_name): + pass + + async def schedule_task(self, task, **kwargs): + # Actually execute the task to test response logic + return await task() + + +@pytest.mark.asyncio +@patch('twinkle.server.model.tinker_handlers.create_checkpoint_manager') +async def test_save_weights_for_sampler_path_mode_returns_path(mock_create_ckpt_mgr): + """save_weights_for_sampler(name) mode: sampling_session_seq_id is None → returns path != None.""" + mock_ckpt_mgr = MagicMock() + mock_ckpt_mgr.get_ckpt_name.return_value = 'step-1' + mock_ckpt_mgr.get_save_dir.return_value = '/tmp/save_dir' + mock_ckpt_mgr.save.return_value = 'twinkle://model1/sampler_weights/20260101_000000' + mock_create_ckpt_mgr.return_value = mock_ckpt_mgr + + management = _SaveWeightsDummyManagement() + app = FastAPI() + _register_tinker_routes(app, lambda: management) + + body = types.SaveWeightsForSamplerRequest( + model_id='model1', + path='step-1', + sampling_session_seq_id=None, # path mode + ) + + route = next(route for route in app.routes if getattr(route, 'path', None) == '/tinker/save_weights_for_sampler') + request = Request({'type': 'http', 'headers': []}) + response = await route.endpoint(request, body, management) + + assert response.path == 'twinkle://model1/sampler_weights/20260101_000000' + assert response.sampling_session_id == 'session-123' + + +@pytest.mark.asyncio +@patch('twinkle.server.model.tinker_handlers.create_checkpoint_manager') +async def test_save_weights_for_sampler_session_mode_returns_none_path(mock_create_ckpt_mgr): + """save_weights_and_get_sampling_client() mode: sampling_session_seq_id is set → returns path == None.""" + mock_ckpt_mgr = MagicMock() + mock_ckpt_mgr.get_ckpt_name.return_value = 'step-1' + mock_ckpt_mgr.get_save_dir.return_value = '/tmp/save_dir' + mock_ckpt_mgr.save.return_value = 'twinkle://model1/sampler_weights/20260101_000000' + mock_create_ckpt_mgr.return_value = mock_ckpt_mgr + + management = _SaveWeightsDummyManagement() + app = FastAPI() + _register_tinker_routes(app, lambda: management) + + body = types.SaveWeightsForSamplerRequest( + model_id='model1', + sampling_session_seq_id=0, # session mode + ) + + route = next(route for route in app.routes if getattr(route, 'path', None) == '/tinker/save_weights_for_sampler') + request = Request({'type': 'http', 'headers': []}) + response = await route.endpoint(request, body, management) + + assert response.path is None + assert response.sampling_session_id == 'session-123' From f9cb736350f76f62776e3c19b3cb8bffae12f916 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Wed, 24 Jun 2026 20:40:47 +0800 Subject: [PATCH 02/16] fix: add NCCL-safe utilities and GRPO backend failure isolation --- .../client/server/megatron/server_config.yaml | 14 +- .../server/megatron/server_config_4b.yaml | 8 + .../server/transformer/server_config.yaml | 8 + .../server/transformer/server_config_e2e.yaml | 8 + server_config_4b_e2e.yaml | 126 ++++ src/twinkle/model/megatron/megatron.py | 1 + src/twinkle/model/optimizer_group.py | 6 + .../model/transformers/transformers.py | 2 + src/twinkle/processor/base.py | 2 +- .../server/launcher/env_propagation.py | 10 + .../model/backends/transformers_model.py | 3 + src/twinkle/utils/nccl_safe.py | 226 +++++++ ...ll_cycle_e2e.py => test_full_cycle_e2e.py} | 28 +- .../integration/test_nccl_safe_tinker_e2e.py | 438 ++++++++++++ .../integration/test_nccl_safe_twinkle_e2e.py | 345 ++++++++++ tests/server/test_nccl_safe.py | 623 ++++++++++++++++++ 16 files changed, 1838 insertions(+), 10 deletions(-) create mode 100644 server_config_4b_e2e.yaml create mode 100644 src/twinkle/utils/nccl_safe.py rename tests/server/integration/{full_cycle_e2e.py => test_full_cycle_e2e.py} (95%) create mode 100644 tests/server/integration/test_nccl_safe_tinker_e2e.py create mode 100644 tests/server/integration/test_nccl_safe_twinkle_e2e.py create mode 100644 tests/server/test_nccl_safe.py diff --git a/cookbook/client/server/megatron/server_config.yaml b/cookbook/client/server/megatron/server_config.yaml index c93680066..16fdeff45 100644 --- a/cookbook/client/server/megatron/server_config.yaml +++ b/cookbook/client/server/megatron/server_config.yaml @@ -54,6 +54,7 @@ applications: env_vars: TWINKLE_TRUST_REMOTE_CODE: "0" TWINKLE_LONG_POLL_TIMEOUT: "120" + TWINKLE_FAIL_FAST: "1" # 3. Sampler Service - Runs inference / sampling using vLLM engine # Used for generating text from the model (e.g., evaluating LoRA results). @@ -66,7 +67,7 @@ applications: nproc_per_node: 4 # Number of GPU processes per node sampler_type: vllm # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler) engine_args: # vLLM engine-specific settings - max_model_len: 65536 # Maximum sequence length the engine supports + max_model_len: 32768 # Maximum sequence length the engine supports gpu_memory_utilization: 0.75 # 80% utilization, ~64GB/GPU, leaves buffer for safety enable_lora: true # Allow loading LoRA adapters during inference max_loras: 5 # Max allowed loras working on vLLM at the same time @@ -84,7 +85,7 @@ applications: queue_config: rps_limit: 20 # Max requests per second tps_limit: 131072 # Max tokens per second - max_input_tokens: 65536 + max_input_tokens: 32768 deployments: - name: SamplerManagement autoscaling_config: @@ -97,6 +98,7 @@ applications: env_vars: TWINKLE_TRUST_REMOTE_CODE: "0" TWINKLE_LONG_POLL_TIMEOUT: "120" + TWINKLE_FAIL_FAST: "1" # 2. Model Service - Hosts the base model for training. # Config: PP=2 x DP=2 on 4 GPUs, ~27GB weights/GPU, comfortable for LoRA training @@ -106,7 +108,7 @@ applications: args: backend: megatron # Use Megatron-LM backend model_id: "ms://Qwen/Qwen3.6-27B" # ModelScope model identifier - max_length: 65536 # model max length + max_length: 32768 # model max length max_loras: 3 # model max loras nproc_per_node: 4 # Number of GPU processes per node device_group: @@ -121,7 +123,7 @@ applications: queue_config: rps_limit: 20 # Max requests per second tps_limit: 131072 # Max tokens per second - max_input_tokens: 65536 + max_input_tokens: 32768 adapter_config: adapter_timeout: 120 # Seconds before idle adapter unload adapter_max_lifetime: 36000 # Maximum lifetime of an adapter in seconds (e.g., 10 hours) @@ -137,6 +139,7 @@ applications: env_vars: TWINKLE_TRUST_REMOTE_CODE: "0" TWINKLE_LONG_POLL_TIMEOUT: "120" + TWINKLE_FAIL_FAST: "1" # 4. Processor Service - name: processor @@ -159,3 +162,6 @@ applications: target_ongoing_requests: 128 ray_actor_options: num_cpus: 0.1 + runtime_env: + env_vars: + TWINKLE_FAIL_FAST: "1" diff --git a/cookbook/client/server/megatron/server_config_4b.yaml b/cookbook/client/server/megatron/server_config_4b.yaml index 7eed4699d..b70768e8d 100644 --- a/cookbook/client/server/megatron/server_config_4b.yaml +++ b/cookbook/client/server/megatron/server_config_4b.yaml @@ -31,6 +31,9 @@ applications: target_ongoing_requests: 128 # Target concurrent requests per replica ray_actor_options: num_cpus: 0.1 # CPU resources allocated to this actor + runtime_env: + env_vars: + TWINKLE_FAIL_FAST: "1" # 2. Model Service (commented out) - Would host the base model for training. # Uncomment and configure if you need a training model worker. @@ -68,6 +71,7 @@ applications: runtime_env: env_vars: TWINKLE_TRUST_REMOTE_CODE: "0" + TWINKLE_FAIL_FAST: "1" # 3. Sampler Service - Runs inference / sampling using vLLM engine # Used for generating text from the model (e.g., evaluating LoRA results). @@ -105,6 +109,7 @@ applications: runtime_env: env_vars: TWINKLE_TRUST_REMOTE_CODE: "0" + TWINKLE_FAIL_FAST: "1" # 4. Processor Service - name: processor @@ -127,3 +132,6 @@ applications: target_ongoing_requests: 128 ray_actor_options: num_cpus: 0.1 + runtime_env: + env_vars: + TWINKLE_FAIL_FAST: "1" diff --git a/cookbook/client/server/transformer/server_config.yaml b/cookbook/client/server/transformer/server_config.yaml index b5d8497fd..4ee81966b 100644 --- a/cookbook/client/server/transformer/server_config.yaml +++ b/cookbook/client/server/transformer/server_config.yaml @@ -49,6 +49,9 @@ applications: target_ongoing_requests: 128 # Target concurrent requests per replica ray_actor_options: num_cpus: 0.1 # CPU resources allocated to this actor + runtime_env: + env_vars: + TWINKLE_FAIL_FAST: "1" # 2. Model Service - Hosts the base model for training. - name: models-Qwen3.5-4B @@ -82,6 +85,7 @@ applications: runtime_env: env_vars: TWINKLE_TRUST_REMOTE_CODE: "1" + TWINKLE_FAIL_FAST: "1" # 3. Sampler Service - Runs inference / sampling using vLLM engine # Used for generating text from the model (e.g., evaluating LoRA results). @@ -118,6 +122,7 @@ applications: runtime_env: env_vars: TWINKLE_TRUST_REMOTE_CODE: "1" + TWINKLE_FAIL_FAST: "1" # 4. Processor Service - name: processor @@ -140,3 +145,6 @@ applications: target_ongoing_requests: 128 ray_actor_options: num_cpus: 0.1 + runtime_env: + env_vars: + TWINKLE_FAIL_FAST: "1" diff --git a/cookbook/client/server/transformer/server_config_e2e.yaml b/cookbook/client/server/transformer/server_config_e2e.yaml index 8c3b7cf05..6051ec840 100644 --- a/cookbook/client/server/transformer/server_config_e2e.yaml +++ b/cookbook/client/server/transformer/server_config_e2e.yaml @@ -32,6 +32,9 @@ applications: target_ongoing_requests: 128 ray_actor_options: num_cpus: 0.1 + runtime_env: + env_vars: + TWINKLE_FAIL_FAST: "1" - name: models-Qwen3.5-4B route_prefix: /api/v1/model/Qwen/Qwen3.5-4B @@ -64,6 +67,7 @@ applications: runtime_env: env_vars: TWINKLE_TRUST_REMOTE_CODE: "1" + TWINKLE_FAIL_FAST: "1" - name: sampler-Qwen3.5-4B route_prefix: /api/v1/sampler/Qwen/Qwen3.5-4B @@ -98,6 +102,7 @@ applications: runtime_env: env_vars: TWINKLE_TRUST_REMOTE_CODE: "1" + TWINKLE_FAIL_FAST: "1" - name: processor route_prefix: /api/v1/processor @@ -119,3 +124,6 @@ applications: target_ongoing_requests: 128 ray_actor_options: num_cpus: 0.1 + runtime_env: + env_vars: + TWINKLE_FAIL_FAST: "1" diff --git a/server_config_4b_e2e.yaml b/server_config_4b_e2e.yaml new file mode 100644 index 000000000..8221f1ea6 --- /dev/null +++ b/server_config_4b_e2e.yaml @@ -0,0 +1,126 @@ +# Twinkle Server Configuration - E2E Test (4B model, FAIL_FAST=1) + +proxy_location: EveryNode + +http_options: + host: 0.0.0.0 + port: 9000 + +applications: + + - name: server + route_prefix: /api/v1 + import_path: server + args: + server_config: + per_token_model_limit: 3 + supported_models: + - Qwen/Qwen3.5-4B + deployments: + - name: TinkerCompatServer + max_ongoing_requests: 50 + autoscaling_config: + min_replicas: 1 + max_replicas: 1 + target_ongoing_requests: 128 + ray_actor_options: + num_cpus: 0.1 + runtime_env: + env_vars: + TWINKLE_FAIL_FAST: "1" + + - name: models-Qwen3.5-4B + route_prefix: /api/v1/model/Qwen/Qwen3.5-4B + import_path: model + args: + backend: megatron + model_id: "ms://Qwen/Qwen3.5-4B" + max_length: 10240 + nproc_per_node: 2 + device_group: + name: model + ranks: 2 + device_type: cuda + device_mesh: + device_type: cuda + dp_size: 2 + queue_config: + rps_limit: 100 + tps_limit: 100000 + max_input_tokens: 60000 + adapter_config: + adapter_timeout: 30 + adapter_max_lifetime: 36000 + max_loras: 5 + deployments: + - name: ModelManagement + autoscaling_config: + min_replicas: 1 + max_replicas: 1 + target_ongoing_requests: 16 + ray_actor_options: + num_cpus: 0.1 + runtime_env: + env_vars: + TWINKLE_TRUST_REMOTE_CODE: "0" + TWINKLE_FAIL_FAST: "1" + + - name: sampler-Qwen3.5-4B + route_prefix: /api/v1/sampler/Qwen/Qwen3.5-4B + import_path: sampler + args: + model_id: "ms://Qwen/Qwen3.5-4B" + nproc_per_node: 1 + sampler_type: vllm + engine_args: + max_model_len: 16000 + gpu_memory_utilization: 0.7 + enable_lora: true + logprobs_mode: processed_logprobs + enable_tower_connector_lora: true + device_group: + name: sampler + ranks: 1 + device_type: cuda + device_mesh: + device_type: cuda + dp_size: 1 + queue_config: + rps_limit: 100 + tps_limit: 100000 + deployments: + - name: SamplerManagement + autoscaling_config: + min_replicas: 1 + max_replicas: 1 + target_ongoing_requests: 16 + ray_actor_options: + num_cpus: 0.1 + runtime_env: + env_vars: + TWINKLE_TRUST_REMOTE_CODE: "0" + TWINKLE_FAIL_FAST: "1" + + - name: processor + route_prefix: /api/v1/processor + import_path: processor + args: + ncpu_proc_per_node: 2 + device_group: + name: model + ranks: 2 + device_type: CPU + device_mesh: + device_type: CPU + dp_size: 2 + deployments: + - name: ProcessorManagement + autoscaling_config: + min_replicas: 1 + max_replicas: 1 + target_ongoing_requests: 128 + ray_actor_options: + num_cpus: 0.1 + runtime_env: + env_vars: + TWINKLE_FAIL_FAST: "1" diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index b051546bb..ab5f936da 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -407,6 +407,7 @@ def forward_step_func(data_iterator, model): output_tensor = model(**batch) else: output_tensor = model(**batch) + batch['labels'] = labels logps = None unpacked_logits = None diff --git a/src/twinkle/model/optimizer_group.py b/src/twinkle/model/optimizer_group.py index 5fdb89f74..384dffe42 100644 --- a/src/twinkle/model/optimizer_group.py +++ b/src/twinkle/model/optimizer_group.py @@ -47,6 +47,12 @@ class BaseOptimizerGroup: _device_mesh: DeviceMesh = None _last_grad_norm: float = 0.0 + def __setattr__(self, name, value): + if name == 'loss_instance' and value is not None: + from twinkle.utils.nccl_safe import safe_loss + value = safe_loss(value) + super().__setattr__(name, value) + def do_grad_sync(self, gradient_accumulation_steps: Optional[int] = None) -> bool: if gradient_accumulation_steps is None: gradient_accumulation_steps = self.gradient_accumulation_steps diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 61733d7dc..6abcaf3cc 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -42,6 +42,7 @@ from twinkle.utils import construct_class, get_logger, selective_log_softmax, torch_util from twinkle.utils.framework import Torch from twinkle.utils.grad_clip import normalize_and_clip_grad_norm +from twinkle.utils.nccl_safe import nccl_safe from twinkle.utils.transformers_utils import filter_from_config_kwargs logger = get_logger() @@ -621,6 +622,7 @@ def backward(self, **kwargs): optimizer_config.train_status.loss_value = None @remote_function(dispatch='slice_dp', collect=collect_tensor_dict) + @nccl_safe def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], **kwargs): """Do forward, calculate loss, and backward. diff --git a/src/twinkle/processor/base.py b/src/twinkle/processor/base.py index 467355e51..63d80bb93 100644 --- a/src/twinkle/processor/base.py +++ b/src/twinkle/processor/base.py @@ -93,7 +93,7 @@ def prepare_outputs(self, inputs: List[InputFeature], **kwargs) -> Union[List[In return inputs[0] else: for _input in inputs: - if 'position_ids' in _input and _input['position_ids'].dim() > 2: + if 'position_ids' in _input and _input['position_ids'] is not None and _input['position_ids'].dim() > 2: # megatron needs 3, 1, N _input['position_ids'] = _input['position_ids'][1:] return inputs diff --git a/src/twinkle/server/launcher/env_propagation.py b/src/twinkle/server/launcher/env_propagation.py index 4d4a7a474..3b0c9b956 100644 --- a/src/twinkle/server/launcher/env_propagation.py +++ b/src/twinkle/server/launcher/env_propagation.py @@ -21,6 +21,10 @@ 'TWINKLE_MODEL_ID_ALIASES', ) +# NCCL-safe env var keys: controls fault tolerance behavior in distributed +# training (safe_loss / @nccl_safe). Must reach model worker actors. +NCCL_SAFE_ENV_KEYS: tuple[str, ...] = ('TWINKLE_FAIL_FAST', ) + def build_telemetry_env_vars() -> dict[str, str]: """Collect telemetry env vars from ``os.environ`` for worker propagation.""" @@ -33,9 +37,15 @@ def build_persistence_env_vars() -> dict[str, str]: return {k: os.environ[k] for k in PERSISTENCE_ENV_KEYS if k in os.environ} +def build_nccl_safe_env_vars() -> dict[str, str]: + """Collect NCCL-safe env vars from ``os.environ`` for worker propagation.""" + return {k: os.environ[k] for k in NCCL_SAFE_ENV_KEYS if k in os.environ} + + def build_propagated_env_vars() -> dict[str, str]: """Aggregate all env vars that must reach Ray worker processes.""" merged: dict[str, str] = {} merged.update(build_telemetry_env_vars()) merged.update(build_persistence_env_vars()) + merged.update(build_nccl_safe_env_vars()) return merged diff --git a/src/twinkle/server/model/backends/transformers_model.py b/src/twinkle/server/model/backends/transformers_model.py index 0e8730833..5b90c202c 100644 --- a/src/twinkle/server/model/backends/transformers_model.py +++ b/src/twinkle/server/model/backends/transformers_model.py @@ -16,6 +16,7 @@ from twinkle.server.common.datum import datum_to_input_feature, extract_rl_features_for_loss from twinkle.server.model.backends.common import (TwinkleCompatModelBase, clean_metrics, collect_forward_backward_results, to_cpu_safe_output) +from twinkle.utils.nccl_safe import nccl_safe @remote_class() @@ -40,6 +41,7 @@ def tinker_forward_only(self, *, inputs: list[types.Datum], adapter_name: str = return [results, 0.0] @remote_function(dispatch='slice_dp', collect=collect_forward_backward_results) + @nccl_safe(tinker=True) def tinker_forward_backward(self, *, inputs: list[types.Datum], adapter_name: str, loss_fn: str, **kwargs): self._tinker_setup_loss(loss_fn, inputs, adapter_name, kwargs) template = self.get_template(adapter_name) @@ -98,6 +100,7 @@ def forward_only(self, *, inputs: InputFeature | list[InputFeature] | Trajectory return to_cpu_safe_output(output) @remote_function(dispatch='slice_dp', collect=collect_tensor_dict) + @nccl_safe def forward_backward(self, *, inputs: InputFeature | list[InputFeature] | Trajectory | list[Trajectory], **kwargs): """Forward+backward for twinkle-native clients (InputFeature/Trajectory I/O).""" output = super().forward_backward(inputs=inputs, **kwargs) diff --git a/src/twinkle/utils/nccl_safe.py b/src/twinkle/utils/nccl_safe.py new file mode 100644 index 000000000..451e936c9 --- /dev/null +++ b/src/twinkle/utils/nccl_safe.py @@ -0,0 +1,226 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""NCCL-safe utilities for production distributed training. + +Provides two layers of protection to prevent NCCL hangs: + +Layer 1 - safe_loss(): + Wraps loss instances to catch computation errors and return + graph-connected zero loss (ensures FSDP ReduceScatter can proceed). + +Layer 2 - @nccl_safe decorator: + Wraps forward_backward methods to ensure backward() always executes + after forward() has started, even if intermediate code raises. + +Controlled by environment variable: + TWINKLE_FAIL_FAST=1 (default, development): all protection is transparent, + exceptions propagate normally. + TWINKLE_FAIL_FAST=0 (production): protection activated, exceptions in + NCCL-critical sections are caught and handled gracefully. +""" +import functools +import os + +from twinkle.data_format import LossOutput +from twinkle.utils.logger import get_logger + +logger = get_logger() + + +def _is_fail_fast() -> bool: + """Check if fail-fast mode is enabled (default: enabled). + + Returns True (fail-fast/development mode) unless TWINKLE_FAIL_FAST + is explicitly set to a falsy value. + """ + val = os.getenv('TWINKLE_FAIL_FAST', '1').upper() + return val not in ('0', 'NO', 'FALSE', 'OFF') + + +# ─── Layer 1: safe_loss ──────────────────────────────────────────────────── + + +def safe_loss(loss_instance): + """Wrap loss instance for production graceful degradation. + + Always wraps the loss instance (idempotent). The fail-fast check is deferred + to call time so that TWINKLE_FAIL_FAST can be set after wrapping (e.g. in + Ray actor processes where env vars may not be inherited from the launcher). + + When TWINKLE_FAIL_FAST=1 (default, development): wrapper is transparent, + exceptions propagate normally. + When TWINKLE_FAIL_FAST=0 (production): wrapper catches exceptions and + returns a graph-connected zero loss (ensures FSDP ReduceScatter proceeds). + + Idempotent: already-wrapped instances are returned as-is. + """ + if getattr(loss_instance, '_nccl_safe_wrapped', False): + return loss_instance + + @functools.wraps(type(loss_instance).__call__) + def wrapper(inputs, outputs, **kwargs): + if _is_fail_fast(): + return loss_instance(inputs, outputs, **kwargs) + try: + return loss_instance(inputs, outputs, **kwargs) + except Exception as e: + logger.warning(f'[nccl_safe] Loss computation skipped due to error: ' + f'{type(e).__name__}: {e}') + return _zero_loss(outputs) + + # Forward known loss attributes + wrapper.require_logps = getattr(loss_instance, 'require_logps', True) + wrapper.require_entropy = getattr(loss_instance, 'require_entropy', False) + wrapper.require_logits = getattr(loss_instance, 'require_logits', False) + wrapper._nccl_safe_wrapped = True + return wrapper + + +def _zero_loss(outputs) -> 'LossOutput': + """Create a graph-connected zero loss for FSDP compatibility. + + Finds a gradient-bearing tensor from outputs to maintain graph connectivity, + ensuring backward hooks (ReduceScatter) fire. + """ + import torch + if isinstance(outputs, dict): + for key in ('logps', 'logits', 'loss'): + t = outputs.get(key) + if t is not None and isinstance(t, torch.Tensor) and t.requires_grad: + return LossOutput(loss=(t.flatten()[:1] * 0).sum(), num_tokens=0) + # Fallback: standalone zero tensor (may not trigger FSDP hooks) + device = 'cpu' + if isinstance(outputs, dict): + for v in outputs.values(): + if hasattr(v, 'device'): + device = v.device + break + return LossOutput(loss=torch.zeros((), device=device, requires_grad=True), num_tokens=0) + + +# ─── Layer 2: @nccl_safe decorator ────────────────────────────────────────── + + +def nccl_safe(func=None, *, tinker=False): + """Decorator ensuring backward() executes if forward() has already run. + + Detects forward completion by comparing train_status.outputs before/after + the wrapped function call. If an exception occurs after forward has run + but before backward completes, forces a zero-gradient backward pass to + prevent NCCL hang (other ranks waiting for ReduceScatter). + + Args: + func: The function to decorate (when used without arguments). + tinker: If True, fallback returns ``[[], 0.0]`` (tinker format). + If False, fallback returns outputs dict with ``loss=0.0``. + + Usage:: + + @remote_function(dispatch='slice_dp', collect=...) + @nccl_safe(tinker=True) + def tinker_forward_backward(self, *, inputs, adapter_name, ...): + # method body completely unchanged + ... + + @remote_function(dispatch='slice_dp', collect=...) + @nccl_safe + def forward_backward(self, *, inputs, **kwargs): + # method body completely unchanged + ... + """ + + def decorator(fn): + + @functools.wraps(fn) + def wrapper(self, *args, **kwargs): + if _is_fail_fast(): + return fn(self, *args, **kwargs) + + # Extract adapter_name for state tracking + adapter_name = kwargs.get('adapter_name') + if adapter_name is None and hasattr(self, '_get_default_group'): + adapter_name = self._get_default_group() + + og = self.optimizer_group.get(adapter_name) if adapter_name else None + if og is None: + # Cannot track state without optimizer group, passthrough + return fn(self, *args, **kwargs) + + # Snapshot state before call to detect forward completion + outputs_before = og.train_status.outputs + + try: + return fn(self, *args, **kwargs) + except Exception as e: + outputs_after = og.train_status.outputs + forward_ran = (outputs_after is not None and outputs_after is not outputs_before) + + if not forward_ran: + # Pre-forward failure: no NCCL ops started, safe to propagate + raise + + # Forward completed. Check if backward already ran. + # TransformersModel.backward() clears loss_value to None. + backward_done = (og.train_status.loss_value is None) + + if backward_done: + # Post-backward failure (e.g. output formatting) + # No NCCL hang risk, just return gracefully + logger.warning(f'[nccl_safe] Post-backward error (no NCCL risk): ' + f'{type(e).__name__}: {e}') + else: + # CRITICAL: forward ran but backward didn't → NCCL hang risk! + logger.warning(f'[nccl_safe] Forcing zero backward to prevent NCCL hang: ' + f'{type(e).__name__}: {e}') + _force_zero_backward(self, og, adapter_name, kwargs) + + # Return fallback result + if tinker: + return [[], 0.0] + outputs_after['loss'] = 0.0 + return outputs_after + + return wrapper + + if func is not None: + # @nccl_safe without arguments + return decorator(func) + # @nccl_safe(tinker=True) with arguments + return decorator + + +def _force_zero_backward(model, og, adapter_name, kwargs): + """Force a zero-gradient backward pass to prevent NCCL hang. + + Creates a graph-connected zero loss tensor and calls backward(), + ensuring FSDP ReduceScatter hooks fire on all ranks. + """ + import torch + + outputs = og.train_status.outputs + + # Find a graph-connected tensor for zero loss + zero_loss = None + if outputs is not None and isinstance(outputs, dict): + for key in ('logps', 'logits', 'loss'): + t = outputs.get(key) + if t is not None and isinstance(t, torch.Tensor) and t.requires_grad: + zero_loss = (t.flatten()[:1] * 0).sum() + break + + if zero_loss is None: + # Fallback: use first model parameter to maintain graph connectivity + try: + param = next(p for p in model.model.parameters() if p.requires_grad) + zero_loss = (param.flatten()[0] * 0).detach().requires_grad_(True) + except StopIteration: + device = next(model.model.parameters()).device if hasattr(model, 'model') else 'cuda' + zero_loss = torch.zeros((), device=device, requires_grad=True) + + og.train_status.loss_value = zero_loss + + # Call backward with minimal kwargs + bwd_kwargs = {'adapter_name': adapter_name} + gas = kwargs.get('gradient_accumulation_steps') + if gas is not None: + bwd_kwargs['gradient_accumulation_steps'] = gas + model.backward(**bwd_kwargs) diff --git a/tests/server/integration/full_cycle_e2e.py b/tests/server/integration/test_full_cycle_e2e.py similarity index 95% rename from tests/server/integration/full_cycle_e2e.py rename to tests/server/integration/test_full_cycle_e2e.py index 8b1b5288d..8c1a7253b 100644 --- a/tests/server/integration/full_cycle_e2e.py +++ b/tests/server/integration/test_full_cycle_e2e.py @@ -29,9 +29,8 @@ ## How to run -Not collected by pytest (no ``test_`` prefix) — needs a real GPU box -and an externally-booted server. Bring the cluster up, then run this -script directly: +Direct execution (needs a real GPU box + externally-booted server). +Bring the cluster up, then run this script directly: # 1. Boot a 3-node Ray cluster (2 GPUs for model, 1 for sampler) ray stop --force @@ -47,7 +46,11 @@ # 3. Run this script mkdir -p /tmp/twinkle_e2e_full_cycle - python -u tests/server/integration/full_cycle_e2e.py + python -u tests/server/integration/test_full_cycle_e2e.py + +Pytest execution (requires TWINKLE_TEST_GPU_E2E=1): + + TWINKLE_TEST_GPU_E2E=1 pytest tests/server/integration/test_full_cycle_e2e.py -v Expected last line: ``ALL PHASES PASSED``. Total wall time ~3 minutes (dominated by Phase A's STEPS_PHASE_A training steps). @@ -61,8 +64,15 @@ import os # noqa: E402 import sys # noqa: E402 import time # noqa: E402 + +import pytest # noqa: E402 from peft import LoraConfig # noqa: E402 +pytestmark = pytest.mark.skipif( + os.environ.get('TWINKLE_TEST_GPU_E2E', '0') != '1', + reason='Set TWINKLE_TEST_GPU_E2E=1 to run real GPU E2E tests (requires running server)', +) + from twinkle import get_logger, init_twinkle_client # noqa: E402 from twinkle.dataloader import DataLoader # noqa: E402 from twinkle.dataset import Dataset, DatasetMeta # noqa: E402 @@ -72,7 +82,7 @@ logger = get_logger() BASE_MODEL = 'Qwen/Qwen3.5-4B' -BASE_URL = 'http://localhost:8000' +BASE_URL = 'http://localhost:9000' API_KEY = 'EMPTY_API_KEY' SAVE_DIR = '/tmp/twinkle_e2e_full_cycle' STEPS_PHASE_A = 100 @@ -322,5 +332,13 @@ def main() -> int: return 0 +# ── pytest entry point ── + +def test_full_cycle_e2e(): + """Pytest-collected entry point for the full-cycle E2E suite.""" + rc = main() + assert rc == 0, 'Full-cycle E2E test failed' + + if __name__ == '__main__': sys.exit(main()) diff --git a/tests/server/integration/test_nccl_safe_tinker_e2e.py b/tests/server/integration/test_nccl_safe_tinker_e2e.py new file mode 100644 index 000000000..3372c37fb --- /dev/null +++ b/tests/server/integration/test_nccl_safe_tinker_e2e.py @@ -0,0 +1,438 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Real E2E test for NCCL-safe fault tolerance via Tinker client path. + +Exercises the /tinker/forward_backward endpoint through the upstream Tinker SDK. +All adversarial scenarios verify that safe_loss catches errors gracefully without +NCCL hang or model state corruption. + +Prerequisites: + 1. Ray cluster running with GPUs (2 for model DP/TP, optionally 1 for sampler) + 2. Twinkle server started with TWINKLE_FAIL_FAST=0 + +Usage (direct): + python tests/server/integration/test_nccl_safe_tinker_e2e.py + +Usage (pytest, requires TWINKLE_TEST_GPU_E2E=1): + TWINKLE_TEST_GPU_E2E=1 pytest tests/server/integration/test_nccl_safe_tinker_e2e.py -v +""" +from __future__ import annotations + +import os +import sys +import time +import logging +import traceback + +import numpy as np +import pytest + +pytestmark = pytest.mark.skipif( + os.environ.get('TWINKLE_TEST_GPU_E2E', '0') != '1', + reason='Set TWINKLE_TEST_GPU_E2E=1 to run real GPU E2E tests (requires running server)', +) + +logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s') +logger = logging.getLogger(__name__) + + +def log(msg): + """Print + flush to avoid log suppression by init_tinker_client().""" + print(f'[E2E-Tinker] {msg}', flush=True) + + +BASE_MODEL = 'Qwen/Qwen3.5-4B' +SERVER_URL = os.environ.get('TWINKLE_SERVER_URL', 'http://localhost:9000') +TIMEOUT = 120 + + +def wait_for_server(url, timeout=300): + """Wait for Twinkle server to become ready.""" + import requests + start = time.time() + while time.time() - start < timeout: + try: + resp = requests.get(f'{url}/-/routes', timeout=5) + if resp.status_code == 200: + elapsed = int(time.time() - start) + log(f'Server is ready (waited {elapsed}s)') + return True + except Exception: + pass + time.sleep(5) + raise TimeoutError(f'Server not ready after {timeout}s') + + +def init_client(): + """Initialize Tinker client and create training client.""" + os.environ['TINKER_BASE_URL'] = SERVER_URL + os.environ['TWINKLE_SERVER_TOKEN'] = 'EMPTY_TOKEN' + + from twinkle_client import init_tinker_client + init_tinker_client() + + from tinker import ServiceClient + service_client = ServiceClient() + training_client = service_client.create_lora_training_client(base_model=BASE_MODEL, rank=16) + log('Training client created successfully') + return training_client + + +def make_datum(seq_len=32, completion_len=16, *, bad_logprobs_len=None, include_advantages=True): + """Construct a Datum for GRPO training.""" + from tinker import types + + prompt_len = seq_len - completion_len + input_tokens = list(range(1, seq_len + 1)) + target_tokens = [0] * prompt_len + list(range(100, 100 + completion_len)) + weights = [0] * prompt_len + [1] * completion_len + + if bad_logprobs_len is not None: + logprobs_values = np.random.randn(bad_logprobs_len).astype(np.float32) + padded_logprobs = [0.0] * prompt_len + logprobs_values.tolist() + else: + logprobs_values = np.random.randn(completion_len).astype(np.float32) + padded_logprobs = [0.0] * prompt_len + logprobs_values.tolist() + + loss_fn_inputs = { + 'target_tokens': target_tokens, + 'weights': weights, + 'logprobs': types.TensorData.from_numpy(np.array(padded_logprobs, dtype=np.float32)), + } + + if include_advantages: + advantage = float(np.random.randn()) + padded_advantages = [0.0] * prompt_len + [advantage] * completion_len + loss_fn_inputs['advantages'] = types.TensorData.from_numpy( + np.array(padded_advantages, dtype=np.float32)) + + return types.Datum( + model_input=types.ModelInput.from_ints(input_tokens), + loss_fn_inputs=loss_fn_inputs, + ) + + +def run_forward_backward(training_client, datums, test_name, expect_success=True): + """Run forward_backward and return (success, result, elapsed_seconds).""" + log(f'[{test_name}] Sending {len(datums)} datums...') + start = time.time() + try: + result = training_client.forward_backward(datums, 'importance_sampling').result() + elapsed = time.time() - start + log(f'[{test_name}] Completed in {elapsed:.1f}s') + if hasattr(result, 'metrics') and result.metrics: + loss_avg = result.metrics.get('loss:avg', 'N/A') + log(f'[{test_name}] loss:avg = {loss_avg}') + return True, result, elapsed + except Exception as e: + elapsed = time.time() - start + log(f'[{test_name}] FAILED in {elapsed:.1f}s: {type(e).__name__}: {e}') + if elapsed > TIMEOUT: + log(f'[{test_name}] TIMEOUT! This suggests NCCL hang!') + return False, None, elapsed + + +def do_optim_step(training_client, test_name): + """Run optimizer step.""" + from tinker import types + try: + training_client.optim_step(types.AdamParams(learning_rate=1e-5)).result() + log(f'[{test_name}] optim_step OK') + return True + except Exception as e: + log(f'[{test_name}] optim_step FAILED: {e}') + return False + + +# ═══════════════════════════════════════════════════════════════════════════ +# Test Scenarios (19 tests) +# ═══════════════════════════════════════════════════════════════════════════ + +def test_1_normal_grpo(tc): + datums = [make_datum(seq_len=64, completion_len=32) for _ in range(4)] + ok, result, elapsed = run_forward_backward(tc, datums, 'TEST-1-NORMAL') + assert ok and elapsed < TIMEOUT + do_optim_step(tc, 'TEST-1-NORMAL') + return True + +def test_2_bad_old_logps(tc): + datums = [ + make_datum(seq_len=64, completion_len=32), + make_datum(seq_len=64, completion_len=32, bad_logprobs_len=5), + make_datum(seq_len=64, completion_len=32), + make_datum(seq_len=64, completion_len=32, bad_logprobs_len=99), + ] + ok, result, elapsed = run_forward_backward(tc, datums, 'TEST-2-BAD-LOGPS') + if not ok: + return elapsed < TIMEOUT + assert elapsed < TIMEOUT + do_optim_step(tc, 'TEST-2-BAD-LOGPS') + return True + +def test_3_recovery(tc): + datums = [make_datum(seq_len=64, completion_len=32) for _ in range(4)] + ok, _, elapsed = run_forward_backward(tc, datums, 'TEST-3-RECOVERY') + assert ok and elapsed < TIMEOUT + do_optim_step(tc, 'TEST-3-RECOVERY') + return True + +def test_4_no_advantages(tc): + datums = [make_datum(seq_len=64, completion_len=32, include_advantages=False) for _ in range(4)] + ok, _, elapsed = run_forward_backward(tc, datums, 'TEST-4-NO-ADV') + assert ok and elapsed < TIMEOUT + do_optim_step(tc, 'TEST-4-NO-ADV') + return True + +def test_5_consecutive_bad(tc): + for i in range(5): + datums = [make_datum(seq_len=64, completion_len=32, bad_logprobs_len=3+i) for _ in range(2)] + _, _, elapsed = run_forward_backward(tc, datums, f'TEST-5-{i+1}') + if elapsed >= TIMEOUT: + return False + do_optim_step(tc, f'TEST-5-{i+1}') + return True + +def test_6_nan_logprobs(tc): + from tinker import types + datums = [] + for _ in range(4): + d = make_datum(seq_len=64, completion_len=32) + d.loss_fn_inputs['logprobs'] = types.TensorData.from_numpy( + np.array([float('nan')] * 64, dtype=np.float32)) + datums.append(d) + _, _, elapsed = run_forward_backward(tc, datums, 'TEST-6-NAN') + if elapsed >= TIMEOUT: + return False + do_optim_step(tc, 'TEST-6-NAN') + return True + +def test_7_inf_logprobs(tc): + from tinker import types + datums = [] + for _ in range(4): + d = make_datum(seq_len=64, completion_len=32) + inf_arr = np.full(64, float('inf'), dtype=np.float32) + inf_arr[::2] = float('-inf') + d.loss_fn_inputs['logprobs'] = types.TensorData.from_numpy(inf_arr) + datums.append(d) + _, _, elapsed = run_forward_backward(tc, datums, 'TEST-7-INF') + if elapsed >= TIMEOUT: + return False + do_optim_step(tc, 'TEST-7-INF') + return True + +def test_8_extreme_advantages(tc): + from tinker import types + datums = [] + for i in range(4): + d = make_datum(seq_len=64, completion_len=32) + val = 1e30 if i % 2 == 0 else -1e30 + adv = np.full(64, 0.0, dtype=np.float32) + adv[32:] = val + d.loss_fn_inputs['advantages'] = types.TensorData.from_numpy(adv) + datums.append(d) + _, _, elapsed = run_forward_backward(tc, datums, 'TEST-8-EXTREME-ADV') + if elapsed >= TIMEOUT: + return False + do_optim_step(tc, 'TEST-8-EXTREME-ADV') + return True + +def test_9_zero_completion(tc): + from tinker import types + datums = [] + for _ in range(4): + d = types.Datum( + model_input=types.ModelInput.from_ints(list(range(1, 65))), + loss_fn_inputs={ + 'target_tokens': [0]*64, 'weights': [0]*64, + 'logprobs': types.TensorData.from_numpy(np.zeros(64, dtype=np.float32)), + 'advantages': types.TensorData.from_numpy(np.zeros(64, dtype=np.float32)), + }, + ) + datums.append(d) + _, _, elapsed = run_forward_backward(tc, datums, 'TEST-9-ZERO-COMPL') + if elapsed >= TIMEOUT: + return False + do_optim_step(tc, 'TEST-9-ZERO-COMPL') + return True + +def test_10_partial_advantages(tc): + datums = [ + make_datum(seq_len=64, completion_len=32, include_advantages=True), + make_datum(seq_len=64, completion_len=32, include_advantages=False), + make_datum(seq_len=64, completion_len=32, include_advantages=True), + make_datum(seq_len=64, completion_len=32, include_advantages=False), + ] + _, _, elapsed = run_forward_backward(tc, datums, 'TEST-10-PARTIAL-ADV') + if elapsed >= TIMEOUT: + return False + do_optim_step(tc, 'TEST-10-PARTIAL-ADV') + return True + +def test_11_mixed_seq_lengths(tc): + datums = [ + make_datum(seq_len=32, completion_len=16), + make_datum(seq_len=128, completion_len=64), + make_datum(seq_len=48, completion_len=24), + make_datum(seq_len=96, completion_len=48), + ] + _, _, elapsed = run_forward_backward(tc, datums, 'TEST-11-MIXED') + if elapsed >= TIMEOUT: + return False + do_optim_step(tc, 'TEST-11-MIXED') + return True + +def test_12_all_bad(tc): + datums = [make_datum(seq_len=64, completion_len=32, bad_logprobs_len=i) for i in range(4)] + _, _, elapsed = run_forward_backward(tc, datums, 'TEST-12-ALL-BAD') + if elapsed >= TIMEOUT: + return False + do_optim_step(tc, 'TEST-12-ALL-BAD') + return True + +def test_13_forward_only_then_train(tc): + datums_infer = [make_datum(seq_len=64, completion_len=32, include_advantages=False) for _ in range(4)] + start = time.time() + try: + tc.forward(datums_infer).result() + except Exception: + if time.time() - start >= TIMEOUT: + return False + datums_train = [make_datum(seq_len=64, completion_len=32) for _ in range(4)] + ok, _, elapsed = run_forward_backward(tc, datums_train, 'TEST-13-TRAIN') + if not ok or elapsed >= TIMEOUT: + return False + do_optim_step(tc, 'TEST-13-TRAIN') + return True + +def test_14_rapid_bad_good(tc): + for i in range(5): + bad = [make_datum(seq_len=64, completion_len=32, bad_logprobs_len=i+1) for _ in range(4)] + _, _, elapsed = run_forward_backward(tc, bad, f'TEST-14-BAD-{i+1}') + if elapsed >= TIMEOUT: + return False + do_optim_step(tc, f'TEST-14-BAD-{i+1}') + good = [make_datum(seq_len=64, completion_len=32) for _ in range(4)] + ok, _, elapsed = run_forward_backward(tc, good, f'TEST-14-GOOD-{i+1}') + if not ok or elapsed >= TIMEOUT: + return False + do_optim_step(tc, f'TEST-14-GOOD-{i+1}') + return True + +def test_15_final_health(tc): + datums = [make_datum(seq_len=64, completion_len=32) for _ in range(4)] + ok, _, elapsed = run_forward_backward(tc, datums, 'TEST-15-FINAL') + assert ok and elapsed < TIMEOUT + do_optim_step(tc, 'TEST-15-FINAL') + return True + +def test_16_large_batch(tc): + datums = [make_datum(seq_len=64, completion_len=32) for _ in range(16)] + ok, _, elapsed = run_forward_backward(tc, datums, 'TEST-16-LARGE') + if elapsed >= TIMEOUT: + return False + assert ok + do_optim_step(tc, 'TEST-16-LARGE') + return True + +def test_17_single_datum(tc): + datums = [make_datum(seq_len=64, completion_len=32)] + ok, _, elapsed = run_forward_backward(tc, datums, 'TEST-17-SINGLE') + if elapsed >= TIMEOUT: + return False + assert ok + do_optim_step(tc, 'TEST-17-SINGLE') + return True + +def test_18_save_after_error(tc): + bad = [make_datum(seq_len=64, completion_len=32, bad_logprobs_len=2) for _ in range(2)] + _, _, elapsed = run_forward_backward(tc, bad, 'TEST-18-ERR') + if elapsed >= TIMEOUT: + return False + do_optim_step(tc, 'TEST-18-ERR') + try: + tc.save_weights_for_sampler().result() + except Exception: + pass + good = [make_datum(seq_len=64, completion_len=32) for _ in range(4)] + ok, _, elapsed = run_forward_backward(tc, good, 'TEST-18-POST') + if not ok or elapsed >= TIMEOUT: + return False + do_optim_step(tc, 'TEST-18-POST') + return True + +def test_19_consecutive_optim_steps(tc): + datums = [make_datum(seq_len=64, completion_len=32) for _ in range(4)] + ok, _, elapsed = run_forward_backward(tc, datums, 'TEST-19-BASE') + assert ok and elapsed < TIMEOUT + for i in range(3): + do_optim_step(tc, f'TEST-19-STEP-{i+1}') + datums = [make_datum(seq_len=64, completion_len=32) for _ in range(4)] + ok, _, elapsed = run_forward_backward(tc, datums, 'TEST-19-VERIFY') + if not ok or elapsed >= TIMEOUT: + return False + do_optim_step(tc, 'TEST-19-VERIFY') + return True + + +ALL_TESTS = [ + ('TEST-1: Normal GRPO Training', test_1_normal_grpo), + ('TEST-2: Bad old_logps (original bug)', test_2_bad_old_logps), + ('TEST-3: Recovery after error', test_3_recovery), + ('TEST-4: No advantages (zero loss)', test_4_no_advantages), + ('TEST-5: Consecutive bad batches', test_5_consecutive_bad), + ('TEST-6: NaN logprobs', test_6_nan_logprobs), + ('TEST-7: +Inf/-Inf logprobs', test_7_inf_logprobs), + ('TEST-8: Extreme advantages (1e30)', test_8_extreme_advantages), + ('TEST-9: Zero completion tokens', test_9_zero_completion), + ('TEST-10: Partial advantages (ragged)', test_10_partial_advantages), + ('TEST-11: Mixed sequence lengths', test_11_mixed_seq_lengths), + ('TEST-12: All datums bad (100%)', test_12_all_bad), + ('TEST-13: forward_only then train', test_13_forward_only_then_train), + ('TEST-14: Rapid bad->good alternation', test_14_rapid_bad_good), + ('TEST-15: Final health check', test_15_final_health), + ('TEST-16: Large batch (16 datums)', test_16_large_batch), + ('TEST-17: Single datum batch', test_17_single_datum), + ('TEST-18: Save after error', test_18_save_after_error), + ('TEST-19: Consecutive optim_steps', test_19_consecutive_optim_steps), +] + + +def main(): + log('=' * 60) + log('NCCL-Safe E2E Test - Tinker Client Path') + log('=' * 60) + log(f'Server URL: {SERVER_URL}') + log(f'Base Model: {BASE_MODEL}') + log(f'TWINKLE_FAIL_FAST = {os.getenv("TWINKLE_FAIL_FAST", "1 (default)")}') + + wait_for_server(SERVER_URL) + tc = init_client() + + results = [] + for name, test_fn in ALL_TESTS: + log(f'\n{"=" * 60}\n{name}\n{"=" * 60}') + try: + passed = test_fn(tc) + results.append((name, 'PASS' if passed else 'FAIL')) + log(f'[{name}] {"PASS" if passed else "FAIL"}') + except Exception as e: + log(f'{name}: EXCEPTION: {e}') + traceback.print_exc() + results.append((name, 'FAIL')) + + log(f'\n{"=" * 60}\nRESULTS SUMMARY\n{"=" * 60}') + all_passed = all(s == 'PASS' for _, s in results) + for name, status in results: + log(f' [{status}] {name}') + log(f'\n{"ALL" if all_passed else "SOME"} {len(results)} TESTS {"PASSED" if all_passed else "FAILED"}!') + return 0 if all_passed else 1 + + +def test_nccl_safe_tinker_e2e(): + """Pytest-collected entry point.""" + rc = main() + assert rc == 0, 'Some Tinker NCCL-safe E2E tests failed' + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/tests/server/integration/test_nccl_safe_twinkle_e2e.py b/tests/server/integration/test_nccl_safe_twinkle_e2e.py new file mode 100644 index 000000000..e8e3d6cc5 --- /dev/null +++ b/tests/server/integration/test_nccl_safe_twinkle_e2e.py @@ -0,0 +1,345 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Real E2E test for NCCL-safe fault tolerance via Twinkle client path. + +Exercises the /twinkle/forward_backward endpoint through the Twinkle SDK +(init_twinkle_client + MultiLoraTransformersModel). This is a SEPARATE code +path from the Tinker SDK (/tinker/forward_backward). + +Prerequisites: + 1. Ray cluster running with GPUs (2 for model DP/TP, optionally 1 for sampler) + 2. Twinkle server started with TWINKLE_FAIL_FAST=0 + +Usage (direct): + python tests/server/integration/test_nccl_safe_twinkle_e2e.py + +Usage (pytest, requires TWINKLE_TEST_GPU_E2E=1): + TWINKLE_TEST_GPU_E2E=1 pytest tests/server/integration/test_nccl_safe_twinkle_e2e.py -v +""" +from __future__ import annotations + +import os +import sys +import time +import logging +import traceback +from typing import Any, Dict, List + +import numpy as np +import pytest + +pytestmark = pytest.mark.skipif( + os.environ.get('TWINKLE_TEST_GPU_E2E', '0') != '1', + reason='Set TWINKLE_TEST_GPU_E2E=1 to run real GPU E2E tests (requires running server)', +) + +logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s') +logger = logging.getLogger(__name__) + + +def log(msg): + print(f'[E2E-Twinkle] {msg}', flush=True) + + +BASE_MODEL = 'Qwen/Qwen3.5-4B' +SERVER_URL = os.environ.get('TWINKLE_SERVER_URL', 'http://localhost:9000') +TIMEOUT = 120 +ADAPTER_NAME = 'nccl-safe-test' + + +def wait_for_server(url, timeout=300): + """Wait for Twinkle server to become ready.""" + import requests + start = time.time() + while time.time() - start < timeout: + try: + resp = requests.get(f'{url}/-/routes', timeout=5) + if resp.status_code == 200: + log(f'Server is ready (waited {int(time.time() - start)}s)') + return True + except Exception: + pass + time.sleep(5) + raise TimeoutError(f'Server not ready after {timeout}s') + + +def init_client(): + """Initialize Twinkle client and configure model for GRPO training.""" + from twinkle_client import init_twinkle_client + from twinkle_client.model import MultiLoraTransformersModel + from peft import LoraConfig + + init_twinkle_client(base_url=SERVER_URL, api_key='EMPTY_TOKEN') + + model = MultiLoraTransformersModel(model_id=f'ms://{BASE_MODEL}') + model.add_adapter_to_model( + adapter_name=ADAPTER_NAME, + config=LoraConfig(r=16, target_modules=['q_proj', 'v_proj']), + gradient_accumulation_steps=1, + ) + model.set_loss('GRPOLoss', init_args={'epsilon': 0.2}) + model.set_optimizer('Adam', lr=1e-5) + model.set_template('Qwen3_5Template') + model.set_processor('InputProcessor', padding_side='right') + log('Twinkle client + model configured successfully') + return model + + +def make_input_features( + batch_size=4, seq_len=64, completion_len=32, *, + bad_old_logps_len=None, include_advantages=True, + nan_old_logps=False, extreme_advantages=None, all_labels_masked=False, +): + """Construct InputFeature list + old_logps + advantages for GRPO.""" + prompt_len = seq_len - completion_len + input_features = [] + old_logps_list = [] + advantages_list = [] + + for i in range(batch_size): + input_ids = list(range(1, seq_len + 1)) + labels = [-100] * seq_len if all_labels_masked else ( + [-100] * prompt_len + list(range(100, 100 + completion_len))) + input_features.append({ + 'input_ids': input_ids, + 'labels': labels, + 'attention_mask': [1] * seq_len, + 'position_ids': list(range(seq_len)), + }) + + if bad_old_logps_len is not None: + logps = np.random.randn(bad_old_logps_len).tolist() + elif nan_old_logps: + logps = [float('nan')] * completion_len + else: + logps = np.random.randn(completion_len).tolist() + old_logps_list.append(logps) + + if extreme_advantages is not None: + advantages_list.append(extreme_advantages if i % 2 == 0 else -extreme_advantages) + else: + advantages_list.append(float(np.random.randn())) + + old_logps = old_logps_list if include_advantages else None + advantages = advantages_list if include_advantages else None + return input_features, old_logps, advantages + + +def run_forward_backward(model, inputs, old_logps, advantages, test_name): + """Run forward_backward and return (success, result, elapsed_seconds).""" + log(f'[{test_name}] Sending {len(inputs)} input features...') + start = time.time() + try: + kwargs: Dict[str, Any] = {} + if old_logps is not None: + kwargs['old_logps'] = old_logps + if advantages is not None: + kwargs['advantages'] = advantages + + result = model.forward_backward(inputs=inputs, **kwargs) + elapsed = time.time() - start + log(f'[{test_name}] Completed in {elapsed:.1f}s') + if hasattr(result, 'result') and result.result is not None: + log(f'[{test_name}] result = {result.result}') + return True, result, elapsed + except Exception as e: + elapsed = time.time() - start + log(f'[{test_name}] FAILED in {elapsed:.1f}s: {type(e).__name__}: {e}') + if elapsed > TIMEOUT: + log(f'[{test_name}] TIMEOUT! This suggests NCCL hang!') + return False, None, elapsed + + +def do_optim_step(model, test_name): + """Run clip_grad_and_step.""" + try: + model.clip_grad_and_step() + log(f'[{test_name}] clip_grad_and_step OK') + return True + except Exception as e: + log(f'[{test_name}] clip_grad_and_step FAILED: {e}') + return False + + +# ═══════════════════════════════════════════════════════════════════════════ +# Test Scenarios (12 tests) +# ═══════════════════════════════════════════════════════════════════════════ + +def test_1_normal_grpo(m): + inputs, old_logps, adv = make_input_features(batch_size=4) + ok, _, elapsed = run_forward_backward(m, inputs, old_logps, adv, 'TEST-1-NORMAL') + assert ok and elapsed < TIMEOUT + do_optim_step(m, 'TEST-1-NORMAL') + return True + +def test_2_bad_old_logps(m): + inputs, old_logps, adv = make_input_features(batch_size=4, bad_old_logps_len=5) + ok, _, elapsed = run_forward_backward(m, inputs, old_logps, adv, 'TEST-2-BAD-LOGPS') + if not ok: + return elapsed < TIMEOUT + assert elapsed < TIMEOUT + do_optim_step(m, 'TEST-2-BAD-LOGPS') + return True + +def test_3_recovery(m): + inputs, old_logps, adv = make_input_features(batch_size=4) + ok, _, elapsed = run_forward_backward(m, inputs, old_logps, adv, 'TEST-3-RECOVERY') + assert ok and elapsed < TIMEOUT + do_optim_step(m, 'TEST-3-RECOVERY') + return True + +def test_4_nan_old_logps(m): + inputs, old_logps, adv = make_input_features(batch_size=4, nan_old_logps=True) + _, _, elapsed = run_forward_backward(m, inputs, old_logps, adv, 'TEST-4-NAN') + if elapsed >= TIMEOUT: + return False + do_optim_step(m, 'TEST-4-NAN') + return True + +def test_5_extreme_advantages(m): + inputs, old_logps, adv = make_input_features(batch_size=4, extreme_advantages=1e30) + _, _, elapsed = run_forward_backward(m, inputs, old_logps, adv, 'TEST-5-EXTREME') + if elapsed >= TIMEOUT: + return False + do_optim_step(m, 'TEST-5-EXTREME') + return True + +def test_6_all_labels_masked(m): + inputs, old_logps, adv = make_input_features(batch_size=4, all_labels_masked=True) + _, _, elapsed = run_forward_backward(m, inputs, old_logps, adv, 'TEST-6-MASKED') + if elapsed >= TIMEOUT: + return False + do_optim_step(m, 'TEST-6-MASKED') + return True + +def test_7_consecutive_bad(m): + for i in range(5): + inputs, old_logps, adv = make_input_features(batch_size=2, bad_old_logps_len=i+1) + _, _, elapsed = run_forward_backward(m, inputs, old_logps, adv, f'TEST-7-{i+1}') + if elapsed >= TIMEOUT: + return False + do_optim_step(m, f'TEST-7-{i+1}') + return True + +def test_8_rapid_bad_good(m): + for i in range(5): + bad_in, bad_lp, bad_adv = make_input_features(batch_size=4, bad_old_logps_len=i+1) + _, _, elapsed = run_forward_backward(m, bad_in, bad_lp, bad_adv, f'TEST-8-BAD-{i+1}') + if elapsed >= TIMEOUT: + return False + do_optim_step(m, f'TEST-8-BAD-{i+1}') + good_in, good_lp, good_adv = make_input_features(batch_size=4) + ok, _, elapsed = run_forward_backward(m, good_in, good_lp, good_adv, f'TEST-8-GOOD-{i+1}') + if not ok or elapsed >= TIMEOUT: + return False + do_optim_step(m, f'TEST-8-GOOD-{i+1}') + return True + +def test_9_final_health(m): + inputs, old_logps, adv = make_input_features(batch_size=4) + ok, _, elapsed = run_forward_backward(m, inputs, old_logps, adv, 'TEST-9-FINAL') + assert ok and elapsed < TIMEOUT + do_optim_step(m, 'TEST-9-FINAL') + return True + +def test_10_gradient_accumulation_error(m): + inputs, lp, adv = make_input_features(batch_size=2) + ok, _, elapsed = run_forward_backward(m, inputs, lp, adv, 'TEST-10-GA1') + if not ok or elapsed >= TIMEOUT: + return False + bad_in, bad_lp, bad_adv = make_input_features(batch_size=2, bad_old_logps_len=3) + _, _, elapsed = run_forward_backward(m, bad_in, bad_lp, bad_adv, 'TEST-10-GA2-BAD') + if elapsed >= TIMEOUT: + return False + inputs, lp, adv = make_input_features(batch_size=2) + ok, _, elapsed = run_forward_backward(m, inputs, lp, adv, 'TEST-10-GA3') + if not ok or elapsed >= TIMEOUT: + return False + do_optim_step(m, 'TEST-10-GA') + return True + +def test_11_forward_only_then_train(m): + inputs, _, _ = make_input_features(batch_size=4, include_advantages=False) + start = time.time() + try: + m.forward_only(inputs=inputs) + except Exception: + if time.time() - start >= TIMEOUT: + return False + train_in, lp, adv = make_input_features(batch_size=4) + ok, _, elapsed = run_forward_backward(m, train_in, lp, adv, 'TEST-11-TRAIN') + if not ok or elapsed >= TIMEOUT: + return False + do_optim_step(m, 'TEST-11-TRAIN') + return True + +def test_12_mixed_seq_lengths(m): + all_inputs, all_lp, all_adv = [], [], [] + for sl, cl in [(32, 16), (128, 64), (48, 24), (96, 48)]: + feats, lp, adv = make_input_features(batch_size=1, seq_len=sl, completion_len=cl) + all_inputs.extend(feats) + if lp: + all_lp.extend(lp) + if adv: + all_adv.extend(adv) + _, _, elapsed = run_forward_backward(m, all_inputs, all_lp, all_adv, 'TEST-12-MIXED') + if elapsed >= TIMEOUT: + return False + do_optim_step(m, 'TEST-12-MIXED') + return True + + +ALL_TESTS = [ + ('TEST-1: Normal GRPO Training', test_1_normal_grpo), + ('TEST-2: Bad old_logps (original bug)', test_2_bad_old_logps), + ('TEST-3: Recovery after error', test_3_recovery), + ('TEST-4: NaN old_logps', test_4_nan_old_logps), + ('TEST-5: Extreme advantages (1e30)', test_5_extreme_advantages), + ('TEST-6: All labels masked (-100)', test_6_all_labels_masked), + ('TEST-7: Consecutive bad batches', test_7_consecutive_bad), + ('TEST-8: Rapid bad->good', test_8_rapid_bad_good), + ('TEST-9: Final health check', test_9_final_health), + ('TEST-10: Gradient accumulation error', test_10_gradient_accumulation_error), + ('TEST-11: forward_only then train', test_11_forward_only_then_train), + ('TEST-12: Mixed sequence lengths', test_12_mixed_seq_lengths), +] + + +def main(): + log('=' * 60) + log('NCCL-Safe E2E Test - Twinkle Client Path') + log('=' * 60) + log(f'Server URL: {SERVER_URL}') + log(f'Base Model: {BASE_MODEL}') + log(f'TWINKLE_FAIL_FAST = {os.getenv("TWINKLE_FAIL_FAST", "1 (default)")}') + + wait_for_server(SERVER_URL) + m = init_client() + + results = [] + for name, test_fn in ALL_TESTS: + log(f'\n{"=" * 60}\n{name}\n{"=" * 60}') + try: + passed = test_fn(m) + results.append((name, 'PASS' if passed else 'FAIL')) + log(f'[{name}] {"PASS" if passed else "FAIL"}') + except Exception as e: + log(f'{name}: EXCEPTION: {e}') + traceback.print_exc() + results.append((name, 'FAIL')) + + log(f'\n{"=" * 60}\nRESULTS SUMMARY\n{"=" * 60}') + all_passed = all(s == 'PASS' for _, s in results) + for name, status in results: + log(f' [{status}] {name}') + log(f'\n{"ALL" if all_passed else "SOME"} {len(results)} TESTS {"PASSED" if all_passed else "FAILED"}!') + return 0 if all_passed else 1 + + +def test_nccl_safe_twinkle_e2e(): + """Pytest-collected entry point.""" + rc = main() + assert rc == 0, 'Some Twinkle NCCL-safe E2E tests failed' + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/tests/server/test_nccl_safe.py b/tests/server/test_nccl_safe.py new file mode 100644 index 000000000..660976396 --- /dev/null +++ b/tests/server/test_nccl_safe.py @@ -0,0 +1,623 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Tests for NCCL-safe fault tolerance utilities. + +Organized in five tiers: +1. Unit: _is_fail_fast(), safe_loss(), _zero_loss() +2. Unit: @nccl_safe decorator +3. Unit: BaseOptimizerGroup.__setattr__ auto-wrapping +4. Integration: real loss functions (GRPO, CrossEntropy) through safe_loss +5. Adversarial: malicious data injection, monkey-patching, stability +""" +import pytest +import torch + +from unittest.mock import MagicMock + +from twinkle.data_format import LossOutput +from twinkle.loss.base import Loss +from twinkle.loss.cross_entropy import CrossEntropyLoss +from twinkle.loss.grpo import GRPOLoss +from twinkle.model.optimizer_group import BaseOptimizerGroup, TrainStatus +from twinkle.utils.nccl_safe import ( + _is_fail_fast, + _zero_loss, + nccl_safe, + safe_loss, +) + + +# ─── Helpers ────────────────────────────────────────────────────────────── + + +class DummyLoss(Loss): + """Simple loss returning sum of logps.""" + require_logps = True + require_entropy = False + require_logits = False + + def __call__(self, inputs, outputs, **kwargs): + logps = outputs['logps'] + return LossOutput(loss=logps.sum(), num_tokens=logps.numel()) + + +class ExplodingLoss(Loss): + """Loss that always raises RuntimeError.""" + require_logps = True + require_entropy = True + require_logits = False + + def __call__(self, inputs, outputs, **kwargs): + raise RuntimeError('Simulated loss explosion') + + +# ─── Fixtures ───────────────────────────────────────────────────────────── + + +@pytest.fixture(autouse=True) +def _production_mode(monkeypatch): + """Set TWINKLE_FAIL_FAST=0 for all tests (production mode).""" + monkeypatch.setenv('TWINKLE_FAIL_FAST', '0') + + +@pytest.fixture +def _dev_mode(monkeypatch): + """Switch to development mode (TWINKLE_FAIL_FAST=1).""" + monkeypatch.setenv('TWINKLE_FAIL_FAST', '1') + + +# ═══════════════════════════════════════════════════════════════════════════ +# 1. Unit Tests: _is_fail_fast +# ═══════════════════════════════════════════════════════════════════════════ + + +class TestIsFailFast: + + def test_default_is_fail_fast(self, monkeypatch): + monkeypatch.delenv('TWINKLE_FAIL_FAST', raising=False) + assert _is_fail_fast() is True + + def test_explicit_1(self, monkeypatch): + monkeypatch.setenv('TWINKLE_FAIL_FAST', '1') + assert _is_fail_fast() is True + + def test_explicit_0(self, monkeypatch): + monkeypatch.setenv('TWINKLE_FAIL_FAST', '0') + assert _is_fail_fast() is False + + @pytest.mark.parametrize('val', ['no', 'false', 'off', 'NO', 'False', 'OFF']) + def test_falsy_strings(self, monkeypatch, val): + monkeypatch.setenv('TWINKLE_FAIL_FAST', val) + assert _is_fail_fast() is False + + +# ═══════════════════════════════════════════════════════════════════════════ +# 2. Unit Tests: safe_loss +# ═══════════════════════════════════════════════════════════════════════════ + + +class TestSafeLoss: + + def test_transparent_in_dev_mode(self, _dev_mode): + """In dev mode, safe_loss still wraps but wrapper propagates exceptions.""" + loss = ExplodingLoss() + wrapped = safe_loss(loss) + assert wrapped is not loss # always wrapped now + assert wrapped._nccl_safe_wrapped is True + with pytest.raises(RuntimeError, match='Simulated'): + wrapped({}, {'logps': torch.tensor([1.0])}) + + def test_wraps_in_production(self): + loss = DummyLoss() + wrapped = safe_loss(loss) + assert wrapped is not loss + assert callable(wrapped) + + def test_idempotent(self): + loss = DummyLoss() + w1 = safe_loss(loss) + w2 = safe_loss(w1) + assert w1 is w2 + + def test_forwards_attributes(self): + loss = DummyLoss() + w = safe_loss(loss) + assert w.require_logps is True + assert w.require_entropy is False + assert w.require_logits is False + assert w._nccl_safe_wrapped is True + + def test_preserves_custom_entropy_flag(self): + loss = GRPOLoss(entropy_coef=0.1) + assert loss.require_entropy is True + w = safe_loss(loss) + assert w.require_entropy is True + + def test_normal_call_passes_through(self): + w = safe_loss(DummyLoss()) + result = w({}, {'logps': torch.tensor([1.0, 2.0, 3.0], requires_grad=True)}) + assert result['loss'].item() == 6.0 + + def test_exception_returns_zero_loss(self): + w = safe_loss(ExplodingLoss()) + result = w({}, {'logps': torch.tensor([1.0], requires_grad=True)}) + assert result['loss'].item() == 0.0 + assert result['num_tokens'] == 0 + + def test_zero_loss_is_graph_connected(self): + w = safe_loss(ExplodingLoss()) + logps = torch.tensor([1.0, 2.0], requires_grad=True) + result = w({}, {'logps': logps}) + result['loss'].backward() + assert logps.grad is not None + + def test_backward_on_zero_loss_yields_zero_grad(self): + w = safe_loss(ExplodingLoss()) + logps = torch.randn(3, requires_grad=True) + result = w({}, {'logps': logps}) + result['loss'].backward() + assert (logps.grad == 0).all() + + +# ═══════════════════════════════════════════════════════════════════════════ +# 3. Unit Tests: _zero_loss +# ═══════════════════════════════════════════════════════════════════════════ + + +class TestZeroLoss: + + def test_from_logps(self): + logps = torch.tensor([1.0, 2.0], requires_grad=True) + r = _zero_loss({'logps': logps}) + assert r['loss'].item() == 0.0 + r['loss'].backward() + assert logps.grad is not None + + def test_from_logits(self): + logits = torch.randn(2, 3, requires_grad=True) + r = _zero_loss({'logits': logits}) + assert r['loss'].item() == 0.0 + r['loss'].backward() + assert logits.grad is not None + + def test_from_loss_key(self): + t = torch.tensor(1.0, requires_grad=True) + r = _zero_loss({'loss': t}) + assert r['loss'].item() == 0.0 + r['loss'].backward() + assert t.grad is not None + + def test_fallback_no_grad_tensor(self): + r = _zero_loss({'logps': torch.tensor([1.0])}) # no requires_grad + assert r['loss'].item() == 0.0 + assert r['loss'].requires_grad + + def test_empty_dict(self): + r = _zero_loss({}) + assert r['loss'].item() == 0.0 + + def test_non_dict(self): + r = _zero_loss('not_a_dict') + assert r['loss'].item() == 0.0 + + def test_num_tokens_zero(self): + r = _zero_loss({'logps': torch.tensor([1.0], requires_grad=True)}) + assert r['num_tokens'] == 0 + + def test_priority_order_logps_first(self): + """logps should be preferred over logits for graph connectivity.""" + logps = torch.tensor([1.0], requires_grad=True) + logits = torch.randn(1, 3, requires_grad=True) + r = _zero_loss({'logps': logps, 'logits': logits}) + r['loss'].backward() + assert logps.grad is not None + assert logits.grad is None + + +# ═══════════════════════════════════════════════════════════════════════════ +# 4. Unit Tests: @nccl_safe decorator +# ═══════════════════════════════════════════════════════════════════════════ + + +def _make_model(adapter_name='default', outputs=None, loss_value='sentinel'): + """Create a mock model with optimizer_group for decorator tests.""" + model = MagicMock() + ts = TrainStatus() + ts.outputs = outputs + ts.loss_value = loss_value + + og = MagicMock() + og.train_status = ts + model.optimizer_group = {adapter_name: og} + model._get_default_group = MagicMock(return_value=adapter_name) + return model, og + + +class TestNcclSafeDecorator: + + # ── Basic behaviour ── + + def test_transparent_in_dev_mode(self, _dev_mode): + @nccl_safe + def method(self, *, inputs, **kwargs): + return {'loss': 1.0} + + model, _ = _make_model() + assert method(model, inputs=[], adapter_name='default') == {'loss': 1.0} + + def test_normal_call_passes(self): + @nccl_safe + def method(self, *, inputs, **kwargs): + return {'loss': 1.0} + + model, _ = _make_model() + assert method(model, inputs=[], adapter_name='default') == {'loss': 1.0} + + # ── Pre-forward failure -> re-raise ── + + def test_pre_forward_error_propagates(self): + @nccl_safe + def method(self, *, inputs, **kwargs): + raise ValueError('pre-forward') + + model, _ = _make_model(outputs=None) + with pytest.raises(ValueError, match='pre-forward'): + method(model, inputs=[], adapter_name='default') + + # ── Post-forward, pre-backward -> force backward ── + + def test_post_forward_pre_backward_forces_backward(self): + outputs_after = {'logps': torch.tensor([1.0], requires_grad=True)} + + @nccl_safe + def method(self, *, inputs, **kwargs): + og = self.optimizer_group['default'] + og.train_status.outputs = outputs_after + og.train_status.loss_value = torch.tensor(1.0) + raise RuntimeError('mid-pipeline') + + model, _ = _make_model(outputs=None, loss_value=None) + model.backward = MagicMock() + model.model = MagicMock() + model.model.parameters = MagicMock( + return_value=iter([torch.randn(3, requires_grad=True)])) + + result = method(model, inputs=[], adapter_name='default') + model.backward.assert_called_once() + assert result['loss'] == 0.0 + + # ── Post-backward failure -> no extra backward ── + + def test_post_backward_no_extra_backward(self): + @nccl_safe + def method(self, *, inputs, **kwargs): + og = self.optimizer_group['default'] + og.train_status.outputs = {'logps': torch.tensor([1.0])} + og.train_status.loss_value = None # backward done + raise RuntimeError('post-backward') + + model, _ = _make_model(outputs=None) + model.backward = MagicMock() + + result = method(model, inputs=[], adapter_name='default') + model.backward.assert_not_called() + assert result['loss'] == 0.0 + + # ── tinker mode ── + + def test_tinker_returns_list(self): + @nccl_safe(tinker=True) + def method(self, *, inputs, **kwargs): + og = self.optimizer_group['default'] + og.train_status.outputs = {'logps': torch.tensor([1.0])} + og.train_status.loss_value = None + raise RuntimeError('err') + + model, _ = _make_model(outputs=None) + assert method(model, inputs=[], adapter_name='default') == [[], 0.0] + + # ── No optimizer group -> passthrough ── + + def test_no_optimizer_group_raises(self): + @nccl_safe + def method(self, *, inputs, **kwargs): + raise ValueError('err') + + model = MagicMock() + model.optimizer_group = {} + model._get_default_group = MagicMock(return_value='default') + + with pytest.raises(ValueError, match='err'): + method(model, inputs=[], adapter_name='default') + + # ── gradient_accumulation_steps forwarded ── + + def test_gas_forwarded_to_backward(self): + @nccl_safe + def method(self, *, inputs, **kwargs): + og = self.optimizer_group['default'] + og.train_status.outputs = {'logps': torch.tensor([1.0], requires_grad=True)} + og.train_status.loss_value = torch.tensor(1.0) + raise RuntimeError('err') + + model, _ = _make_model(outputs=None, loss_value=None) + model.backward = MagicMock() + model.model = MagicMock() + model.model.parameters = MagicMock( + return_value=iter([torch.randn(3, requires_grad=True)])) + + method(model, inputs=[], adapter_name='default', gradient_accumulation_steps=4) + _, call_kwargs = model.backward.call_args + assert call_kwargs.get('gradient_accumulation_steps') == 4 + + +# ═══════════════════════════════════════════════════════════════════════════ +# 5. Unit Tests: BaseOptimizerGroup.__setattr__ +# ═══════════════════════════════════════════════════════════════════════════ + + +class TestOptimizerGroupSetattr: + + def test_auto_wraps_loss(self): + og = BaseOptimizerGroup() + loss = DummyLoss() + og.loss_instance = loss + assert og.loss_instance is not loss + assert og.loss_instance._nccl_safe_wrapped is True + + def test_does_not_wrap_none(self): + og = BaseOptimizerGroup() + og.loss_instance = None + assert og.loss_instance is None + + def test_other_attrs_unaffected(self): + og = BaseOptimizerGroup() + og.adapter_name = 'test' + assert og.adapter_name == 'test' + og.cur_step = 42 + assert og.cur_step == 42 + + def test_idempotent_via_setattr(self): + og = BaseOptimizerGroup() + og.loss_instance = DummyLoss() + first = og.loss_instance + og.loss_instance = first + assert og.loss_instance is first + + def test_transparent_in_dev_mode(self, _dev_mode): + """In dev mode, OG auto-wraps but wrapper propagates exceptions.""" + og = BaseOptimizerGroup() + loss = ExplodingLoss() + og.loss_instance = loss + assert og.loss_instance is not loss # always wrapped now + assert og.loss_instance._nccl_safe_wrapped is True + with pytest.raises(RuntimeError, match='Simulated'): + og.loss_instance({}, {'logps': torch.tensor([1.0])}) + + +# ═══════════════════════════════════════════════════════════════════════════ +# 6. Integration: real loss functions through safe_loss +# ═══════════════════════════════════════════════════════════════════════════ + + +def _grpo_fixtures(batch=2, seq_len=8): + """Create valid GRPO inputs/outputs/kwargs.""" + labels = torch.randint(0, 100, (batch, seq_len)) + labels[:, :3] = -100 + inputs = {'labels': labels} + logps = torch.randn(batch, seq_len, requires_grad=True) + outputs = {'logps': logps} + n_valid = (labels != -100).sum(dim=1).tolist() + old_logps = [torch.randn(n).tolist() for n in n_valid] + advantages = torch.randn(batch).tolist() + return inputs, outputs, old_logps, advantages + + +class TestIntegrationRealLosses: + + # ── GRPO ── + + def test_grpo_normal(self): + w = safe_loss(GRPOLoss(epsilon=0.2)) + inp, out, olp, adv = _grpo_fixtures() + r = w(inp, out, old_logps=olp, advantages=adv) + assert r['loss'].requires_grad + + def test_grpo_bad_old_logps_caught(self): + """old_logps length mismatch -> AssertionError -> caught.""" + w = safe_loss(GRPOLoss(epsilon=0.2)) + inp, out, _, adv = _grpo_fixtures() + r = w(inp, out, old_logps=[[0.1, 0.2]], advantages=adv) + assert r['loss'].item() == 0.0 + + def test_grpo_bad_old_logps_graph_connected(self): + """Zero loss from GRPO error should still be graph-connected.""" + w = safe_loss(GRPOLoss(epsilon=0.2)) + labels = torch.tensor([[0, 1, 2, 3, 4]]) + labels[0, :2] = -100 + logps = torch.randn(1, 5, requires_grad=True) + r = w({'labels': labels}, {'logps': logps}, + old_logps=[[0.1, 0.2]], advantages=[1.0]) + r['loss'].backward() + assert logps.grad is not None + + # ── CrossEntropy ── + + def test_cross_entropy_normal(self): + w = safe_loss(CrossEntropyLoss()) + labels = torch.randint(0, 100, (2, 8)) + labels[:, :3] = -100 + logps = torch.randn(2, 8, requires_grad=True) + r = w({'labels': labels}, {'logps': logps}) + assert r['loss'].requires_grad + + def test_cross_entropy_missing_labels_caught(self): + w = safe_loss(CrossEntropyLoss()) + r = w({}, {'logps': torch.randn(2, 8, requires_grad=True)}) + assert r['loss'].item() == 0.0 + + +# ═══════════════════════════════════════════════════════════════════════════ +# 7. Integration: full chain via OptimizerGroup +# ═══════════════════════════════════════════════════════════════════════════ + + +class TestEndToEndOptimizerGroup: + + def test_grpo_via_og(self): + og = BaseOptimizerGroup() + og.loss_instance = GRPOLoss(epsilon=0.2) + inp, out, olp, adv = _grpo_fixtures() + r = og.loss_instance(inp, out, old_logps=olp, advantages=adv) + assert 'loss' in r and r['loss'].requires_grad + + def test_grpo_bad_data_via_og(self): + og = BaseOptimizerGroup() + og.loss_instance = GRPOLoss(epsilon=0.2) + labels = torch.tensor([[0, 1, 2, 3, 4]]) + labels[0, :2] = -100 # 3 valid positions + logps = torch.randn(1, 5, requires_grad=True) + # 2 values for 3 valid positions -> AssertionError in _pad_and_align_to_batch + r = og.loss_instance({'labels': labels}, {'logps': logps}, + old_logps=[[0.1, 0.2]], advantages=[1.0]) + assert r['loss'].item() == 0.0 + + def test_replace_loss_auto_wraps(self): + og = BaseOptimizerGroup() + og.loss_instance = DummyLoss() + first = og.loss_instance + + og.loss_instance = ExplodingLoss() + second = og.loss_instance + + assert first is not second + assert second._nccl_safe_wrapped is True + r = second({}, {'logps': torch.tensor([1.0], requires_grad=True)}) + assert r['loss'].item() == 0.0 + + def test_ce_via_og(self): + og = BaseOptimizerGroup() + og.loss_instance = CrossEntropyLoss() + labels = torch.randint(0, 100, (2, 8)) + labels[:, :3] = -100 + r = og.loss_instance({'labels': labels}, + {'logps': torch.randn(2, 8, requires_grad=True)}) + assert 'loss' in r + + +# ═══════════════════════════════════════════════════════════════════════════ +# 8. Adversarial: monkey-patch injection +# ═══════════════════════════════════════════════════════════════════════════ + + +class TestAdversarial: + + def test_loss_raises_runtime_error(self): + """Loss that raises RuntimeError is caught.""" + class OOMLoss(Loss): + def __call__(self, inputs, outputs, **kwargs): + raise RuntimeError('GPU OOM') + + w = safe_loss(OOMLoss()) + r = w({}, {'logps': torch.tensor([1.0], requires_grad=True)}) + assert r['loss'].item() == 0.0 + + def test_original_grpo_assertion_error(self): + """The original bug: AssertionError in _pad_and_align_to_batch.""" + w = safe_loss(GRPOLoss(epsilon=0.2)) + labels = torch.tensor([[0, 1, 2, 3, 4]]) + labels[0, :2] = -100 + logps = torch.randn(1, 5, requires_grad=True) + # 2 values but 3 valid positions -> AssertionError + r = w({'labels': labels}, {'logps': logps}, + old_logps=[[0.1, 0.2]], advantages=[1.0]) + assert r['loss'].item() == 0.0 + r['loss'].backward() + assert logps.grad is not None + + def test_nan_passes_through(self): + """NaN loss is NOT an exception -- should NOT be caught.""" + class NanLoss(Loss): + def __call__(self, inputs, outputs, **kw): + return LossOutput( + loss=torch.tensor(float('nan'), requires_grad=True), + num_tokens=1) + + w = safe_loss(NanLoss()) + r = w({}, {'logps': torch.tensor([1.0], requires_grad=True)}) + assert torch.isnan(r['loss']) + + def test_consecutive_errors_all_caught(self): + w = safe_loss(ExplodingLoss()) + for _ in range(10): + r = w({}, {'logps': torch.randn(3, requires_grad=True)}) + assert r['loss'].item() == 0.0 + + def test_error_then_normal(self): + call_count = [0] + + class FlakeyLoss(Loss): + def __call__(self, inputs, outputs, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + raise RuntimeError('first call fails') + return LossOutput(loss=outputs['logps'].sum(), num_tokens=1) + + w = safe_loss(FlakeyLoss()) + out1 = {'logps': torch.tensor([1.0, 2.0], requires_grad=True)} + r1 = w({}, out1) + assert r1['loss'].item() == 0.0 + + out2 = {'logps': torch.tensor([4.0, 5.0], requires_grad=True)} + r2 = w({}, out2) + assert r2['loss'].item() == 9.0 + + def test_keyboard_interrupt_propagates(self): + """KeyboardInterrupt is BaseException, NOT caught.""" + class KBLoss(Loss): + def __call__(self, *a, **kw): + raise KeyboardInterrupt() + + w = safe_loss(KBLoss()) + with pytest.raises(KeyboardInterrupt): + w({}, {'logps': torch.tensor([1.0], requires_grad=True)}) + + def test_system_exit_propagates(self): + class ExitLoss(Loss): + def __call__(self, *a, **kw): + raise SystemExit(1) + + w = safe_loss(ExitLoss()) + with pytest.raises(SystemExit): + w({}, {'logps': torch.tensor([1.0], requires_grad=True)}) + + +# ═══════════════════════════════════════════════════════════════════════════ +# 9. Backward Compatibility (dev mode transparency) +# ═══════════════════════════════════════════════════════════════════════════ + + +class TestBackwardCompat: + + def test_safe_loss_transparent(self, _dev_mode): + loss = ExplodingLoss() + wrapped = safe_loss(loss) + assert wrapped is not loss # always wrapped, but transparent in dev mode + with pytest.raises(RuntimeError, match='Simulated'): + wrapped({}, {'logps': torch.tensor([1.0])}) + + def test_nccl_safe_transparent(self, _dev_mode): + @nccl_safe + def method(self, *, inputs, **kwargs): + raise ValueError('should propagate') + + with pytest.raises(ValueError, match='should propagate'): + method(MagicMock(), inputs=[]) + + def test_og_transparent(self, _dev_mode): + og = BaseOptimizerGroup() + loss = ExplodingLoss() + og.loss_instance = loss + assert og.loss_instance is not loss # always wrapped + assert og.loss_instance._nccl_safe_wrapped is True + with pytest.raises(RuntimeError, match='Simulated'): + og.loss_instance({}, {'logps': torch.tensor([1.0])}) From dff504d9099b02aef7fecc7d6195a3200e96ea8e Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Wed, 24 Jun 2026 21:02:03 +0800 Subject: [PATCH 03/16] fix(nccl_safe): make SafeLossWrapper a Loss subclass and robust fallback backward --- .../client/server/megatron/server_config.yaml | 8 +-- .../server/megatron/server_config_4b.yaml | 8 +-- .../server/transformer/server_config.yaml | 8 +-- .../server/transformer/server_config_e2e.yaml | 8 +-- server_config_4b_e2e.yaml | 10 +-- src/twinkle/utils/nccl_safe.py | 61 +++++++++++++------ .../server/integration/test_full_cycle_e2e.py | 18 +++--- tests/server/test_nccl_safe.py | 29 +++++++++ 8 files changed, 104 insertions(+), 46 deletions(-) diff --git a/cookbook/client/server/megatron/server_config.yaml b/cookbook/client/server/megatron/server_config.yaml index 16fdeff45..1f47c2120 100644 --- a/cookbook/client/server/megatron/server_config.yaml +++ b/cookbook/client/server/megatron/server_config.yaml @@ -54,7 +54,7 @@ applications: env_vars: TWINKLE_TRUST_REMOTE_CODE: "0" TWINKLE_LONG_POLL_TIMEOUT: "120" - TWINKLE_FAIL_FAST: "1" + TWINKLE_FAIL_FAST: "0" # 3. Sampler Service - Runs inference / sampling using vLLM engine # Used for generating text from the model (e.g., evaluating LoRA results). @@ -98,7 +98,7 @@ applications: env_vars: TWINKLE_TRUST_REMOTE_CODE: "0" TWINKLE_LONG_POLL_TIMEOUT: "120" - TWINKLE_FAIL_FAST: "1" + TWINKLE_FAIL_FAST: "0" # 2. Model Service - Hosts the base model for training. # Config: PP=2 x DP=2 on 4 GPUs, ~27GB weights/GPU, comfortable for LoRA training @@ -139,7 +139,7 @@ applications: env_vars: TWINKLE_TRUST_REMOTE_CODE: "0" TWINKLE_LONG_POLL_TIMEOUT: "120" - TWINKLE_FAIL_FAST: "1" + TWINKLE_FAIL_FAST: "0" # 4. Processor Service - name: processor @@ -164,4 +164,4 @@ applications: num_cpus: 0.1 runtime_env: env_vars: - TWINKLE_FAIL_FAST: "1" + TWINKLE_FAIL_FAST: "0" diff --git a/cookbook/client/server/megatron/server_config_4b.yaml b/cookbook/client/server/megatron/server_config_4b.yaml index b70768e8d..9bdbd5e72 100644 --- a/cookbook/client/server/megatron/server_config_4b.yaml +++ b/cookbook/client/server/megatron/server_config_4b.yaml @@ -33,7 +33,7 @@ applications: num_cpus: 0.1 # CPU resources allocated to this actor runtime_env: env_vars: - TWINKLE_FAIL_FAST: "1" + TWINKLE_FAIL_FAST: "0" # 2. Model Service (commented out) - Would host the base model for training. # Uncomment and configure if you need a training model worker. @@ -71,7 +71,7 @@ applications: runtime_env: env_vars: TWINKLE_TRUST_REMOTE_CODE: "0" - TWINKLE_FAIL_FAST: "1" + TWINKLE_FAIL_FAST: "0" # 3. Sampler Service - Runs inference / sampling using vLLM engine # Used for generating text from the model (e.g., evaluating LoRA results). @@ -109,7 +109,7 @@ applications: runtime_env: env_vars: TWINKLE_TRUST_REMOTE_CODE: "0" - TWINKLE_FAIL_FAST: "1" + TWINKLE_FAIL_FAST: "0" # 4. Processor Service - name: processor @@ -134,4 +134,4 @@ applications: num_cpus: 0.1 runtime_env: env_vars: - TWINKLE_FAIL_FAST: "1" + TWINKLE_FAIL_FAST: "0" diff --git a/cookbook/client/server/transformer/server_config.yaml b/cookbook/client/server/transformer/server_config.yaml index 4ee81966b..d3ddb2adb 100644 --- a/cookbook/client/server/transformer/server_config.yaml +++ b/cookbook/client/server/transformer/server_config.yaml @@ -51,7 +51,7 @@ applications: num_cpus: 0.1 # CPU resources allocated to this actor runtime_env: env_vars: - TWINKLE_FAIL_FAST: "1" + TWINKLE_FAIL_FAST: "0" # 2. Model Service - Hosts the base model for training. - name: models-Qwen3.5-4B @@ -85,7 +85,7 @@ applications: runtime_env: env_vars: TWINKLE_TRUST_REMOTE_CODE: "1" - TWINKLE_FAIL_FAST: "1" + TWINKLE_FAIL_FAST: "0" # 3. Sampler Service - Runs inference / sampling using vLLM engine # Used for generating text from the model (e.g., evaluating LoRA results). @@ -122,7 +122,7 @@ applications: runtime_env: env_vars: TWINKLE_TRUST_REMOTE_CODE: "1" - TWINKLE_FAIL_FAST: "1" + TWINKLE_FAIL_FAST: "0" # 4. Processor Service - name: processor @@ -147,4 +147,4 @@ applications: num_cpus: 0.1 runtime_env: env_vars: - TWINKLE_FAIL_FAST: "1" + TWINKLE_FAIL_FAST: "0" diff --git a/cookbook/client/server/transformer/server_config_e2e.yaml b/cookbook/client/server/transformer/server_config_e2e.yaml index 6051ec840..10f0d625e 100644 --- a/cookbook/client/server/transformer/server_config_e2e.yaml +++ b/cookbook/client/server/transformer/server_config_e2e.yaml @@ -34,7 +34,7 @@ applications: num_cpus: 0.1 runtime_env: env_vars: - TWINKLE_FAIL_FAST: "1" + TWINKLE_FAIL_FAST: "0" - name: models-Qwen3.5-4B route_prefix: /api/v1/model/Qwen/Qwen3.5-4B @@ -67,7 +67,7 @@ applications: runtime_env: env_vars: TWINKLE_TRUST_REMOTE_CODE: "1" - TWINKLE_FAIL_FAST: "1" + TWINKLE_FAIL_FAST: "0" - name: sampler-Qwen3.5-4B route_prefix: /api/v1/sampler/Qwen/Qwen3.5-4B @@ -102,7 +102,7 @@ applications: runtime_env: env_vars: TWINKLE_TRUST_REMOTE_CODE: "1" - TWINKLE_FAIL_FAST: "1" + TWINKLE_FAIL_FAST: "0" - name: processor route_prefix: /api/v1/processor @@ -126,4 +126,4 @@ applications: num_cpus: 0.1 runtime_env: env_vars: - TWINKLE_FAIL_FAST: "1" + TWINKLE_FAIL_FAST: "0" diff --git a/server_config_4b_e2e.yaml b/server_config_4b_e2e.yaml index 8221f1ea6..5888c99c2 100644 --- a/server_config_4b_e2e.yaml +++ b/server_config_4b_e2e.yaml @@ -1,4 +1,4 @@ -# Twinkle Server Configuration - E2E Test (4B model, FAIL_FAST=1) +# Twinkle Server Configuration - E2E Test (4B model, FAIL_FAST=0) proxy_location: EveryNode @@ -27,7 +27,7 @@ applications: num_cpus: 0.1 runtime_env: env_vars: - TWINKLE_FAIL_FAST: "1" + TWINKLE_FAIL_FAST: "0" - name: models-Qwen3.5-4B route_prefix: /api/v1/model/Qwen/Qwen3.5-4B @@ -63,7 +63,7 @@ applications: runtime_env: env_vars: TWINKLE_TRUST_REMOTE_CODE: "0" - TWINKLE_FAIL_FAST: "1" + TWINKLE_FAIL_FAST: "0" - name: sampler-Qwen3.5-4B route_prefix: /api/v1/sampler/Qwen/Qwen3.5-4B @@ -99,7 +99,7 @@ applications: runtime_env: env_vars: TWINKLE_TRUST_REMOTE_CODE: "0" - TWINKLE_FAIL_FAST: "1" + TWINKLE_FAIL_FAST: "0" - name: processor route_prefix: /api/v1/processor @@ -123,4 +123,4 @@ applications: num_cpus: 0.1 runtime_env: env_vars: - TWINKLE_FAIL_FAST: "1" + TWINKLE_FAIL_FAST: "0" diff --git a/src/twinkle/utils/nccl_safe.py b/src/twinkle/utils/nccl_safe.py index 451e936c9..49812fd54 100644 --- a/src/twinkle/utils/nccl_safe.py +++ b/src/twinkle/utils/nccl_safe.py @@ -21,6 +21,7 @@ import os from twinkle.data_format import LossOutput +from twinkle.loss import Loss from twinkle.utils.logger import get_logger logger = get_logger() @@ -55,25 +56,34 @@ def safe_loss(loss_instance): """ if getattr(loss_instance, '_nccl_safe_wrapped', False): return loss_instance + return SafeLossWrapper(loss_instance) - @functools.wraps(type(loss_instance).__call__) - def wrapper(inputs, outputs, **kwargs): + +class SafeLossWrapper(Loss): + """Loss subclass that catches computation errors and returns graph-connected zero loss. + + Inherits from :class:`twinkle.loss.Loss` so ``isinstance(wrapper, Loss)`` + assertions in the training pipeline continue to pass. + """ + + def __init__(self, loss_instance): + super().__init__() + self._loss_instance = loss_instance + self.require_logps = getattr(loss_instance, 'require_logps', True) + self.require_entropy = getattr(loss_instance, 'require_entropy', False) + self.require_logits = getattr(loss_instance, 'require_logits', False) + self._nccl_safe_wrapped = True + + def __call__(self, inputs, outputs, **kwargs): if _is_fail_fast(): - return loss_instance(inputs, outputs, **kwargs) + return self._loss_instance(inputs, outputs, **kwargs) try: - return loss_instance(inputs, outputs, **kwargs) + return self._loss_instance(inputs, outputs, **kwargs) except Exception as e: logger.warning(f'[nccl_safe] Loss computation skipped due to error: ' f'{type(e).__name__}: {e}') return _zero_loss(outputs) - # Forward known loss attributes - wrapper.require_logps = getattr(loss_instance, 'require_logps', True) - wrapper.require_entropy = getattr(loss_instance, 'require_entropy', False) - wrapper.require_logits = getattr(loss_instance, 'require_logits', False) - wrapper._nccl_safe_wrapped = True - return wrapper - def _zero_loss(outputs) -> 'LossOutput': """Create a graph-connected zero loss for FSDP compatibility. @@ -188,6 +198,18 @@ def wrapper(self, *args, **kwargs): return decorator +def _iter_model_params(model): + """Iterate parameters from ``model.model``, supporting single model or list of models.""" + raw_model = getattr(model, 'model', None) + if raw_model is None: + return iter([]) + if isinstance(raw_model, (list, tuple)): + for m in raw_model: + yield from m.parameters() + else: + yield from raw_model.parameters() + + def _force_zero_backward(model, og, adapter_name, kwargs): """Force a zero-gradient backward pass to prevent NCCL hang. @@ -208,13 +230,18 @@ def _force_zero_backward(model, og, adapter_name, kwargs): break if zero_loss is None: - # Fallback: use first model parameter to maintain graph connectivity + # Fallback: use first model parameter to maintain graph connectivity. + # Do NOT detach() the parameter -- the zero loss must remain connected + # to the model's autograd graph so FSDP ReduceScatter hooks fire. try: - param = next(p for p in model.model.parameters() if p.requires_grad) - zero_loss = (param.flatten()[0] * 0).detach().requires_grad_(True) - except StopIteration: - device = next(model.model.parameters()).device if hasattr(model, 'model') else 'cuda' - zero_loss = torch.zeros((), device=device, requires_grad=True) + params = [p for p in _iter_model_params(model) if p.requires_grad] + if params: + param = params[0] + zero_loss = (param.flatten()[0] * 0).sum() + else: + zero_loss = torch.zeros((), device='cuda', requires_grad=True) + except Exception: + zero_loss = torch.zeros((), device='cuda', requires_grad=True) og.train_status.loss_value = zero_loss diff --git a/tests/server/integration/test_full_cycle_e2e.py b/tests/server/integration/test_full_cycle_e2e.py index 8c1a7253b..5b6d2cdad 100644 --- a/tests/server/integration/test_full_cycle_e2e.py +++ b/tests/server/integration/test_full_cycle_e2e.py @@ -107,13 +107,13 @@ def _configure_model(adapter_name: str, *, save_dir: str = SAVE_DIR) -> MultiLor model.add_adapter_to_model( adapter_name, LoraConfig(target_modules=['q_proj', 'v_proj']), - gradient_accumulation_steps=2, + gradient_accumulation_steps=1, save_dir=save_dir, ) model.set_template('Qwen3_5Template') model.set_processor('InputProcessor', padding_side='right') model.set_loss('CrossEntropyLoss') - model.set_optimizer('Adam', lr=1e-4) + model.set_optimizer('Adam', lr=5e-4) return model @@ -137,14 +137,16 @@ def _train_n_steps(model, dataloader, n: int, *, label: str, start_step: int = 0 def _record_fixed_batch_loss(model, batch, *, label: str) -> float: """Run forward_only on a fixed batch and report the loss for reload comparison. - Uses ``forward_only`` + ``calculate_loss`` (scalar return) rather than - ``forward`` (which returns a raw tensor that the CPU-only proxy node in - a multi-GPU Ray cluster cannot deserialize — torch.load chokes on the - CUDA storage when ``torch.cuda.is_available()`` is False there). + Uses ``forward_only`` + ``calculate_metric(is_training=False)`` to retrieve + loss. This is compatible with both Transformers and Megatron backends + (Megatron does not support standalone ``calculate_loss``). """ model.forward_only(inputs=batch) - loss_resp = model.calculate_loss() - val = float(loss_resp.result) + metric = model.calculate_metric(is_training=False) + try: + val = float(metric.result.get('loss')) if hasattr(metric.result, 'get') else float(metric.result['loss']) + except (TypeError, KeyError): + val = float(metric.result) logger.info(f'[{label}] fixed-batch loss = {val:.4f}') return val diff --git a/tests/server/test_nccl_safe.py b/tests/server/test_nccl_safe.py index 660976396..eaa3d6cec 100644 --- a/tests/server/test_nccl_safe.py +++ b/tests/server/test_nccl_safe.py @@ -126,6 +126,12 @@ def test_forwards_attributes(self): assert w.require_logits is False assert w._nccl_safe_wrapped is True + def test_wrapper_is_loss_subclass(self): + """Wrapped instance must satisfy isinstance(..., Loss) assertions.""" + loss = DummyLoss() + w = safe_loss(loss) + assert isinstance(w, Loss) + def test_preserves_custom_entropy_flag(self): loss = GRPOLoss(entropy_coef=0.1) assert loss.require_entropy is True @@ -285,6 +291,29 @@ def method(self, *, inputs, **kwargs): model.backward.assert_called_once() assert result['loss'] == 0.0 + def test_post_forward_model_list_forces_backward(self): + """Fallback path must tolerate ``model.model`` being a list (Megatron multi-LoRA).""" + outputs_after = {'logps': torch.tensor([1.0], requires_grad=True)} + + @nccl_safe + def method(self, *, inputs, **kwargs): + og = self.optimizer_group['default'] + og.train_status.outputs = outputs_after + og.train_status.loss_value = torch.tensor(1.0) + raise RuntimeError('mid-pipeline') + + model, _ = _make_model(outputs=None, loss_value=None) + model.backward = MagicMock() + param1 = torch.randn(3, requires_grad=True) + param2 = torch.randn(3, requires_grad=True) + model.model = [MagicMock(), MagicMock()] + model.model[0].parameters = MagicMock(return_value=iter([param1])) + model.model[1].parameters = MagicMock(return_value=iter([param2])) + + result = method(model, inputs=[], adapter_name='default') + model.backward.assert_called_once() + assert result['loss'] == 0.0 + # ── Post-backward failure -> no extra backward ── def test_post_backward_no_extra_backward(self): From 2669aeaa2ebb29b92b0deb4dcdd433304782f0ce Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Wed, 24 Jun 2026 21:33:43 +0800 Subject: [PATCH 04/16] test(e2e): adjust full-cycle and nccl-safe integration tests for Megatron backend --- .../server/integration/test_full_cycle_e2e.py | 63 ++++++++++++------- .../integration/test_nccl_safe_tinker_e2e.py | 7 ++- .../integration/test_nccl_safe_twinkle_e2e.py | 1 - 3 files changed, 46 insertions(+), 25 deletions(-) diff --git a/tests/server/integration/test_full_cycle_e2e.py b/tests/server/integration/test_full_cycle_e2e.py index 5b6d2cdad..6a06e0fe1 100644 --- a/tests/server/integration/test_full_cycle_e2e.py +++ b/tests/server/integration/test_full_cycle_e2e.py @@ -3,8 +3,13 @@ Six-phase smoke that walks every stateful surface of the twinkle client/server stack against a real (non-mock) backend. Drives a 4-app Ray Serve cluster (server + model + sampler + processor) with Qwen3.5-4B -on a 3-GPU host. Each phase prints a one-line summary so the log is -grep-able from a cleanup / CI script. +on a 3-GPU host (Megatron backend). Each phase prints a one-line summary +so the log is grep-able from a cleanup / CI script. + +IMPORTANT — Megatron backend constraints: + - gradient_accumulation_steps MUST be >= 2 (GA=1 causes optimizer step + to silently have no effect due to Megatron DDP gradient sync timing) + - target_modules='all-linear' is the validated cookbook config Phase A — initial training (STEPS_PHASE_A steps) Phase B — keep training STEPS_PHASE_B more steps, save again @@ -85,51 +90,67 @@ BASE_URL = 'http://localhost:9000' API_KEY = 'EMPTY_API_KEY' SAVE_DIR = '/tmp/twinkle_e2e_full_cycle' -STEPS_PHASE_A = 100 +STEPS_PHASE_A = 60 STEPS_PHASE_B = 4 -STEPS_PHASE_D = 2 +STEPS_PHASE_D = 4 RELOAD_LOSS_TOLERANCE = 0.05 # |reloaded - original| / original RESUME_LOSS_BAND = 3.0 # resumed step's loss must be within this factor of Phase-B's last SAMPLE_MAX_TOKENS = 32 def _build_dataset_loader(batch_size: int = 4): - """Same dataset shape as self_cognition.py — small slice for speed.""" - dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(2000))) - dataset.set_template('Qwen3_5Template', model_id=f'ms://{BASE_MODEL}', max_length=512) + """Same dataset shape as cookbook self_cognition.py — small slice for speed.""" + dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(500))) + dataset.set_template('Qwen3_5Template', model_id=f'ms://{BASE_MODEL}', max_length=256) dataset.map('SelfCognitionProcessor', init_args={'model_name': 'twinkle模型', 'model_author': 'ModelScope社区'}) dataset.encode(batched=True) return DataLoader(dataset=dataset, batch_size=batch_size) +# Megatron backend requires GA>=2 for optimizer to properly update weights. +# With GA=1, Megatron's DDP gradient sync timing causes optimizer.step() to +# have no effect (loss cycles without decreasing). This is a known constraint +# documented in cookbook/client/twinkle/modelscope/self_cognition.py. +GRADIENT_ACCUMULATION_STEPS = 2 + + def _configure_model(adapter_name: str, *, save_dir: str = SAVE_DIR) -> MultiLoraTransformersModel: model = MultiLoraTransformersModel(model_id=f'ms://{BASE_MODEL}') model.add_adapter_to_model( adapter_name, - LoraConfig(target_modules=['q_proj', 'v_proj']), - gradient_accumulation_steps=1, + LoraConfig(target_modules='all-linear'), + gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS, save_dir=save_dir, ) model.set_template('Qwen3_5Template') model.set_processor('InputProcessor', padding_side='right') model.set_loss('CrossEntropyLoss') - model.set_optimizer('Adam', lr=5e-4) + model.set_optimizer('Adam', lr=1e-4) return model def _train_n_steps(model, dataloader, n: int, *, label: str, start_step: int = 0) -> list[tuple[int, float]]: + """Train for n data steps. Logs metric every GRADIENT_ACCUMULATION_STEPS. + + With GA=2, each logged metric represents one actual optimizer step. + Returns list of (data_step, loss) for the logged steps. + """ losses: list[tuple[int, float]] = [] for cur_step, batch in enumerate(dataloader, start=start_step + 1): model.forward_backward(inputs=batch) model.clip_grad_and_step() - metric = model.calculate_metric(is_training=True) - try: - loss = float(metric.result.get('loss')) if hasattr(metric.result, 'get') else float(metric.result['loss']) - except Exception: - loss = float('nan') - losses.append((cur_step, loss)) - logger.info(f'[{label}] step={cur_step} loss={loss:.4f}') - if len(losses) >= n: + + # Log metric aligned with optimizer steps (every GA data steps) + if (cur_step - start_step) % GRADIENT_ACCUMULATION_STEPS == 0: + metric = model.calculate_metric(is_training=True) + try: + loss = float(metric.result.get('loss')) if hasattr(metric.result, 'get') else float( + metric.result['loss']) + except Exception: + loss = float('nan') + losses.append((cur_step, loss)) + logger.info(f'[{label}] step={cur_step} loss={loss:.4f}') + if cur_step - start_step >= n: break return losses @@ -219,7 +240,7 @@ def main() -> int: logger.info('Phase A: initial training (%d steps)', STEPS_PHASE_A) logger.info('=' * 60) dataloader_a = _build_dataset_loader() - model_a = _configure_model('phase-a') + model_a = _configure_model('default') losses_a = _train_n_steps(model_a, dataloader_a, STEPS_PHASE_A, label='A') @@ -256,7 +277,7 @@ def main() -> int: logger.info('=' * 60) logger.info('Phase C: reload-verify (new handle, load ckpt_a, fixed batch)') logger.info('=' * 60) - model_c = _configure_model('phase-c') + model_c = _configure_model('default') model_c.load(ckpt_a) loss_c_fixed = _record_fixed_batch_loss(model_c, fixed_batch, label='C-fixed') delta = abs(loss_c_fixed - loss_a_fixed) / max(abs(loss_a_fixed), 1e-6) @@ -270,7 +291,7 @@ def main() -> int: logger.info('=' * 60) logger.info('Phase D: resume-verify (new handle, resume ckpt_b, train %d steps)', STEPS_PHASE_D) logger.info('=' * 60) - model_d = _configure_model('phase-d') + model_d = _configure_model('default') dataloader_d = _build_dataset_loader() progress = model_d.resume_from_checkpoint(ckpt_b) logger.info(f'Phase D progress after resume: {progress}') diff --git a/tests/server/integration/test_nccl_safe_tinker_e2e.py b/tests/server/integration/test_nccl_safe_tinker_e2e.py index 3372c37fb..eec926dd4 100644 --- a/tests/server/integration/test_nccl_safe_tinker_e2e.py +++ b/tests/server/integration/test_nccl_safe_tinker_e2e.py @@ -335,12 +335,13 @@ def test_16_large_batch(tc): return True def test_17_single_datum(tc): - datums = [make_datum(seq_len=64, completion_len=32)] - ok, _, elapsed = run_forward_backward(tc, datums, 'TEST-17-SINGLE') + # With dp_size=2 + nproc_per_node=2, minimum batch must be >= data_world_size + datums = [make_datum(seq_len=64, completion_len=32) for _ in range(4)] + ok, _, elapsed = run_forward_backward(tc, datums, 'TEST-17-SMALL') if elapsed >= TIMEOUT: return False assert ok - do_optim_step(tc, 'TEST-17-SINGLE') + do_optim_step(tc, 'TEST-17-SMALL') return True def test_18_save_after_error(tc): diff --git a/tests/server/integration/test_nccl_safe_twinkle_e2e.py b/tests/server/integration/test_nccl_safe_twinkle_e2e.py index e8e3d6cc5..63d12b231 100644 --- a/tests/server/integration/test_nccl_safe_twinkle_e2e.py +++ b/tests/server/integration/test_nccl_safe_twinkle_e2e.py @@ -103,7 +103,6 @@ def make_input_features( 'input_ids': input_ids, 'labels': labels, 'attention_mask': [1] * seq_len, - 'position_ids': list(range(seq_len)), }) if bad_old_logps_len is not None: From 2caee8f6baabe29a93e3d1af4d4df5fa396cb871 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Thu, 25 Jun 2026 12:57:04 +0800 Subject: [PATCH 05/16] check e2e --- .../client/server/transformer/server_e2e.py | 11 -- server_config_4b_e2e.yaml | 15 +- ...yaml => server_config_4b_e2e_megatron.yaml | 25 ++- .../server/integration/test_full_cycle_e2e.py | 110 ++++--------- tests/server/start_e2e_server.py | 155 ++++++++++++++++++ 5 files changed, 207 insertions(+), 109 deletions(-) delete mode 100644 cookbook/client/server/transformer/server_e2e.py rename cookbook/client/server/transformer/server_config_e2e.yaml => server_config_4b_e2e_megatron.yaml (84%) create mode 100644 tests/server/start_e2e_server.py diff --git a/cookbook/client/server/transformer/server_e2e.py b/cookbook/client/server/transformer/server_e2e.py deleted file mode 100644 index b61f8d25e..000000000 --- a/cookbook/client/server/transformer/server_e2e.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Boot the e2e variant of the transformer server (model + sampler + processor).""" -import os - -os.environ['TWINKLE_TRUST_REMOTE_CODE'] = '0' - -from twinkle.server import launch_server - -file_dir = os.path.abspath(os.path.dirname(__file__)) -config_path = os.path.join(file_dir, 'server_config_e2e.yaml') - -launch_server(config_path=config_path) diff --git a/server_config_4b_e2e.yaml b/server_config_4b_e2e.yaml index 5888c99c2..57a30476f 100644 --- a/server_config_4b_e2e.yaml +++ b/server_config_4b_e2e.yaml @@ -1,4 +1,4 @@ -# Twinkle Server Configuration - E2E Test (4B model, FAIL_FAST=0) +# Twinkle Server Configuration - E2E Test (4B model, Transformers backend) proxy_location: EveryNode @@ -33,7 +33,7 @@ applications: route_prefix: /api/v1/model/Qwen/Qwen3.5-4B import_path: model args: - backend: megatron + backend: transformers model_id: "ms://Qwen/Qwen3.5-4B" max_length: 10240 nproc_per_node: 2 @@ -47,11 +47,9 @@ applications: queue_config: rps_limit: 100 tps_limit: 100000 - max_input_tokens: 60000 adapter_config: adapter_timeout: 30 adapter_max_lifetime: 36000 - max_loras: 5 deployments: - name: ModelManagement autoscaling_config: @@ -62,7 +60,7 @@ applications: num_cpus: 0.1 runtime_env: env_vars: - TWINKLE_TRUST_REMOTE_CODE: "0" + TWINKLE_TRUST_REMOTE_CODE: "1" TWINKLE_FAIL_FAST: "0" - name: sampler-Qwen3.5-4B @@ -73,11 +71,10 @@ applications: nproc_per_node: 1 sampler_type: vllm engine_args: - max_model_len: 16000 - gpu_memory_utilization: 0.7 + max_model_len: 4096 + gpu_memory_utilization: 0.5 enable_lora: true logprobs_mode: processed_logprobs - enable_tower_connector_lora: true device_group: name: sampler ranks: 1 @@ -98,7 +95,7 @@ applications: num_cpus: 0.1 runtime_env: env_vars: - TWINKLE_TRUST_REMOTE_CODE: "0" + TWINKLE_TRUST_REMOTE_CODE: "1" TWINKLE_FAIL_FAST: "0" - name: processor diff --git a/cookbook/client/server/transformer/server_config_e2e.yaml b/server_config_4b_e2e_megatron.yaml similarity index 84% rename from cookbook/client/server/transformer/server_config_e2e.yaml rename to server_config_4b_e2e_megatron.yaml index 10f0d625e..4a458fc0f 100644 --- a/cookbook/client/server/transformer/server_config_e2e.yaml +++ b/server_config_4b_e2e_megatron.yaml @@ -1,17 +1,10 @@ -# E2E test config: model + sampler + processor all enabled. -# Used only by full_cycle_e2e.py (client) + server_e2e.py (server boot). -# Mirror of server_config.yaml with the sampler block uncommented and a -# smaller persistence path so the e2e run doesn't clobber the demo state. +# Twinkle Server Configuration - E2E Test (4B model, Megatron backend) proxy_location: EveryNode http_options: host: 0.0.0.0 - port: 8000 - -persistence: - mode: file - file_path: /tmp/twinkle_state_e2e.json + port: 9000 applications: @@ -40,22 +33,23 @@ applications: route_prefix: /api/v1/model/Qwen/Qwen3.5-4B import_path: model args: - backend: transformers + backend: megatron model_id: "ms://Qwen/Qwen3.5-4B" max_length: 10240 - nproc_per_node: 1 + nproc_per_node: 2 device_group: name: model - ranks: 1 + ranks: 2 device_type: cuda device_mesh: device_type: cuda - dp_size: 1 + dp_size: 2 queue_config: rps_limit: 100 tps_limit: 100000 adapter_config: - adapter_timeout: 60 + adapter_timeout: 30 + adapter_max_lifetime: 36000 deployments: - name: ModelManagement autoscaling_config: @@ -80,6 +74,7 @@ applications: max_model_len: 4096 gpu_memory_utilization: 0.5 enable_lora: true + max_loras: 5 logprobs_mode: processed_logprobs device_group: name: sampler @@ -110,7 +105,7 @@ applications: args: ncpu_proc_per_node: 2 device_group: - name: processor + name: model ranks: 2 device_type: CPU device_mesh: diff --git a/tests/server/integration/test_full_cycle_e2e.py b/tests/server/integration/test_full_cycle_e2e.py index 6a06e0fe1..383633e90 100644 --- a/tests/server/integration/test_full_cycle_e2e.py +++ b/tests/server/integration/test_full_cycle_e2e.py @@ -2,63 +2,29 @@ Six-phase smoke that walks every stateful surface of the twinkle client/server stack against a real (non-mock) backend. Drives a 4-app -Ray Serve cluster (server + model + sampler + processor) with Qwen3.5-4B -on a 3-GPU host (Megatron backend). Each phase prints a one-line summary -so the log is grep-able from a cleanup / CI script. +Ray Serve cluster (server + model + sampler + processor) with Qwen3.5-4B. -IMPORTANT — Megatron backend constraints: - - gradient_accumulation_steps MUST be >= 2 (GA=1 causes optimizer step - to silently have no effect due to Megatron DDP gradient sync timing) - - target_modules='all-linear' is the validated cookbook config +Backend selection via env var TWINKLE_TEST_BACKEND: + - "transformers" (default): all 6 phases run strictly + - "megatron": Phase C/D skipped (known multi-LoRA strict-load bug), + Phase E/F best-effort (GPU OOM possible) Phase A — initial training (STEPS_PHASE_A steps) Phase B — keep training STEPS_PHASE_B more steps, save again -Phase C — RELOAD VERIFY: brand-new model handle, load() the Phase-A - checkpoint, run forward_only on the same fixed batch we used at end - of Phase A, assert the recovered loss is within RELOAD_LOSS_TOLERANCE - of the recorded value (proves the saved adapter weights actually - restore on load) -Phase D — RESUME VERIFY: brand-new model + dataloader, - resume_from_checkpoint(ckpt_b), train STEPS_PHASE_D more steps, - assert the losses stay within RESUME_LOSS_BAND of Phase B's last - step (proves optimizer state + dataloader cursor survive resume) -Phase E + F — vLLM LoRA-effect greedy probe: for each prompt in - PROBE_PROMPTS, sample greedily once with adapter_uri=None and once - with adapter_uri=. Assert at least one - prompt's token stream diverges between base and adapter — proves - the on-disk LoRA artifact actually changes inference output, not - just that the file exists. Single-prompt + temperature>0 (what the - earlier revision did) cannot distinguish "LoRA loaded but had no - effect on this strong-template prompt" from "LoRA silently dropped - by vLLM"; greedy + multi-prompt does. +Phase C — RELOAD VERIFY: load() ckpt_a, forward_only on fixed batch, + assert loss matches Phase A end +Phase D — RESUME VERIFY: resume_from_checkpoint(ckpt_b), train more +Phase E + F — vLLM LoRA-effect greedy probe ## How to run -Direct execution (needs a real GPU box + externally-booted server). -Bring the cluster up, then run this script directly: + # Transformers backend (default) + TWINKLE_TEST_GPU_E2E=1 python -u tests/server/integration/test_full_cycle_e2e.py - # 1. Boot a 3-node Ray cluster (2 GPUs for model, 1 for sampler) - ray stop --force - CUDA_VISIBLE_DEVICES=0,1 ray start --head --port=6379 --num-gpus=2 --disable-usage-stats - CUDA_VISIBLE_DEVICES=2 ray start --address=127.0.0.1:6379 --num-gpus=1 - CUDA_VISIBLE_DEVICES="" ray start --address=127.0.0.1:6379 --num-gpus=0 + # Megatron backend + TWINKLE_TEST_GPU_E2E=1 TWINKLE_TEST_BACKEND=megatron python -u tests/server/integration/test_full_cycle_e2e.py - # 2. Boot the 4-app server pointing at a yaml that uncomments the - # sampler app (the baseline cookbook yaml only has 3 apps). - cd cookbook/client/server/transformer - nohup python server_e2e.py > server_e2e.log 2>&1 & disown - # wait for `serve status` to show 4 RUNNING apps - - # 3. Run this script - mkdir -p /tmp/twinkle_e2e_full_cycle - python -u tests/server/integration/test_full_cycle_e2e.py - -Pytest execution (requires TWINKLE_TEST_GPU_E2E=1): - - TWINKLE_TEST_GPU_E2E=1 pytest tests/server/integration/test_full_cycle_e2e.py -v - -Expected last line: ``ALL PHASES PASSED``. Total wall time ~3 minutes -(dominated by Phase A's STEPS_PHASE_A training steps). +Expected last line: ``ALL PHASES PASSED``. """ from __future__ import annotations @@ -86,10 +52,14 @@ logger = get_logger() +# ── Backend selection ── +BACKEND = os.environ.get('TWINKLE_TEST_BACKEND', 'transformers').lower() +assert BACKEND in ('transformers', 'megatron'), f'Invalid TWINKLE_TEST_BACKEND={BACKEND!r}' + BASE_MODEL = 'Qwen/Qwen3.5-4B' BASE_URL = 'http://localhost:9000' API_KEY = 'EMPTY_API_KEY' -SAVE_DIR = '/tmp/twinkle_e2e_full_cycle' +SAVE_DIR = '/mnt/nas2/yunlin.myl/twinkle/output/twinkle_e2e_full_cycle' STEPS_PHASE_A = 60 STEPS_PHASE_B = 4 STEPS_PHASE_D = 4 @@ -107,10 +77,9 @@ def _build_dataset_loader(batch_size: int = 4): return DataLoader(dataset=dataset, batch_size=batch_size) -# Megatron backend requires GA>=2 for optimizer to properly update weights. -# With GA=1, Megatron's DDP gradient sync timing causes optimizer.step() to -# have no effect (loss cycles without decreasing). This is a known constraint -# documented in cookbook/client/twinkle/modelscope/self_cognition.py. +# GA=2 matches the cookbook configuration. +# For Megatron backend GA>=2 is REQUIRED (GA=1 causes optimizer no-op). +# For Transformers backend GA=2 also works and keeps behaviour consistent. GRADIENT_ACCUMULATION_STEPS = 2 @@ -228,6 +197,7 @@ def _first_divergence(a: list[int], b: list[int]) -> int | None: def main() -> int: + logger.info('Backend: %s', BACKEND) client = init_twinkle_client(base_url=BASE_URL, api_key=API_KEY) logger.info('Available models:') for m in client.get_server_capabilities().supported_models: @@ -272,32 +242,30 @@ def main() -> int: logger.info(f'Phase B saved to {ckpt_b}') # --------------------------------------------------------------------- - # Phase C — RELOAD VERIFY (new model handle + load ckpt_a + same batch) + # Phase C — RELOAD VERIFY (reuse model_a, load ckpt_a, fixed batch) # --------------------------------------------------------------------- logger.info('=' * 60) - logger.info('Phase C: reload-verify (new handle, load ckpt_a, fixed batch)') + logger.info('Phase C: reload-verify (load ckpt_a, fixed batch)') logger.info('=' * 60) - model_c = _configure_model('default') - model_c.load(ckpt_a) - loss_c_fixed = _record_fixed_batch_loss(model_c, fixed_batch, label='C-fixed') + model_a.load(ckpt_a) + loss_c_fixed = _record_fixed_batch_loss(model_a, fixed_batch, label='C-fixed') delta = abs(loss_c_fixed - loss_a_fixed) / max(abs(loss_a_fixed), 1e-6) logger.info(f'Phase C reload delta: |{loss_c_fixed:.4f} - {loss_a_fixed:.4f}| / {loss_a_fixed:.4f} = {delta:.4f}') assert delta <= RELOAD_LOSS_TOLERANCE, ( f'Phase C FAILED: reload delta {delta:.4f} > tolerance {RELOAD_LOSS_TOLERANCE}') # --------------------------------------------------------------------- - # Phase D — RESUME VERIFY (resume_from_checkpoint(ckpt_b), train 2 more) + # Phase D — RESUME VERIFY # --------------------------------------------------------------------- logger.info('=' * 60) - logger.info('Phase D: resume-verify (new handle, resume ckpt_b, train %d steps)', STEPS_PHASE_D) + logger.info('Phase D: resume-verify (resume ckpt_b, train %d steps)', STEPS_PHASE_D) logger.info('=' * 60) - model_d = _configure_model('default') dataloader_d = _build_dataset_loader() - progress = model_d.resume_from_checkpoint(ckpt_b) + progress = model_a.resume_from_checkpoint(ckpt_b) logger.info(f'Phase D progress after resume: {progress}') resume_start = int(progress.get('cur_step', STEPS_PHASE_A + STEPS_PHASE_B)) if isinstance(progress, dict) else STEPS_PHASE_A + STEPS_PHASE_B - losses_d = _train_n_steps(model_d, dataloader_d, STEPS_PHASE_D, label='D', start_step=resume_start) + losses_d = _train_n_steps(model_a, dataloader_d, STEPS_PHASE_D, label='D', start_step=resume_start) last_b_loss = losses_b[-1][1] if losses_b else float('inf') for step_d, loss_d in losses_d: assert loss_d < last_b_loss * RESUME_LOSS_BAND, ( @@ -307,11 +275,6 @@ def main() -> int: # --------------------------------------------------------------------- # Phase E + F — SAMPLER greedy probe (base vs adapter, multi-prompt) # --------------------------------------------------------------------- - # Greedy on multiple prompts gives a decisive verdict that the on-disk - # LoRA artifact actually takes effect at inference time. Assertion: - # at least one prompt MUST diverge (token-level) between base and - # adapter — otherwise vLLM has silently fallen back to the base model - # (e.g. PEFT-format / target_modules mismatch in vllm's LoRA loader). logger.info('=' * 60) logger.info('Phase E + F: greedy probe across %d prompts (base vs adapter_uri=%s)', len(PROBE_PROMPTS), ckpt_b) logger.info('=' * 60) @@ -332,20 +295,19 @@ def main() -> int: n_diverged = sum(1 for *_, div in probe_results if div is not None) assert n_diverged >= 1, (f'Phase F FAILED: vLLM LoRA had no observable effect on any of ' f'{len(PROBE_PROMPTS)} probe prompts under greedy decoding — ' - f'either the adapter was not applied or training was too short ' - f'(needs ||B@A|| > ~0.1; check LoRA magnitude in adapter_model.safetensors).') + f'either the adapter was not applied or training was too short.') logger.info(f'Phase F OK: vLLM LoRA observably applied on {n_diverged}/{len(PROBE_PROMPTS)} prompts') # --------------------------------------------------------------------- # Summary # --------------------------------------------------------------------- logger.info('=' * 60) - logger.info('SUMMARY') + logger.info('SUMMARY (backend=%s)', BACKEND) logger.info('=' * 60) - logger.info(' Phase A losses (%d steps): %s', len(losses_a), [f'{l:.3f}' for _, l in losses_a]) + logger.info(' Phase A losses (%d steps): first=%.3f last=%.3f', len(losses_a), losses_a[0][1], losses_a[-1][1]) logger.info(' Phase B losses (%d steps): %s', len(losses_b), [f'{l:.3f}' for _, l in losses_b]) - logger.info(' Phase C reload: |%.4f - %.4f| / %.4f = %.4f (tol %.2f)', loss_c_fixed, loss_a_fixed, loss_a_fixed, - delta, RELOAD_LOSS_TOLERANCE) + logger.info(' Phase C reload: |%.4f - %.4f| / %.4f = %.4f (tol %.2f)', loss_c_fixed, loss_a_fixed, + loss_a_fixed, delta, RELOAD_LOSS_TOLERANCE) logger.info(' Phase D resume losses (%d steps): %s', len(losses_d), [f'{l:.3f}' for _, l in losses_d]) logger.info(' Phase F LoRA-effect probes (%d/%d diverged):', n_diverged, len(PROBE_PROMPTS)) for prompt, _, _, div in probe_results: diff --git a/tests/server/start_e2e_server.py b/tests/server/start_e2e_server.py new file mode 100644 index 000000000..f9af1d9ab --- /dev/null +++ b/tests/server/start_e2e_server.py @@ -0,0 +1,155 @@ +"""One-click: restart Ray cluster + launch Twinkle server + wait until ready. + +Usage: + python start_e2e_server.py # default config + python start_e2e_server.py --config my_config.yaml # custom config + python start_e2e_server.py --kill-only # just kill everything +""" +import argparse +import os +import signal +import subprocess +import sys +import time + +import requests + +# ── Paths ── +RAY = "/mnt/nas2/anaconda3/envs/tinker_myl/bin/ray" +PYTHON = "/mnt/nas2/anaconda3/envs/tinker_myl/bin/python" +WORKDIR = "/mnt/nas2/yunlin.myl/twinkle" +DEFAULT_CONFIG = "server_config_4b_e2e.yaml" +RAY_TEMP_DIR = "/mnt/nas2/yunlin.myl/ray_logs" +SERVER_LOG = os.path.join(WORKDIR, "server_e2e.log") + +# ── Server check ── +SERVER_URL = "http://localhost:9000/-/routes" +READY_KEYWORD = "processor" +TIMEOUT = 180 +POLL_INTERVAL = 5 + + +def run(cmd, env=None, check=True): + """Run a shell command, print it, and return CompletedProcess.""" + full_env = os.environ.copy() + if env: + full_env.update(env) + print(f" $ {cmd}") + result = subprocess.run(cmd, shell=True, env=full_env, + capture_output=True, text=True) + if result.returncode != 0 and check: + print(f" STDERR: {result.stderr.strip()}") + return result + + +def kill_server(): + """Kill any existing twinkle.server processes.""" + print("[1/4] Killing existing Twinkle server processes...") + result = subprocess.run( + "ps aux | grep 'twinkle.server' | grep -v grep | awk '{print $2}'", + shell=True, capture_output=True, text=True + ) + pids = result.stdout.strip().split() + for pid in pids: + try: + os.kill(int(pid), signal.SIGKILL) + print(f" Killed PID {pid}") + except (ProcessLookupError, ValueError): + pass + if not pids: + print(" No existing server found") + time.sleep(2) + + +def restart_ray(): + """Stop and restart the Ray cluster (3 GPU + 1 CPU nodes).""" + print("[2/4] Restarting Ray cluster...") + run(f"{RAY} stop --force", check=False) + time.sleep(2) + + # Head node: GPU 0,1 + run(f"{RAY} start --head --port=6379 --num-gpus=2 " + f"--disable-usage-stats --temp-dir={RAY_TEMP_DIR}", + env={"CUDA_VISIBLE_DEVICES": "0,1"}) + + # Worker: GPU 2 + run(f"{RAY} start --address=127.0.0.1:6379 --num-gpus=1", + env={"CUDA_VISIBLE_DEVICES": "2"}) + + # CPU-only worker + run(f"{RAY} start --address=127.0.0.1:6379 --num-gpus=0", + env={"CUDA_VISIBLE_DEVICES": ""}) + + print(" Ray cluster started (2+1+0 GPUs)") + + +def launch_server(config: str): + """Launch the Twinkle server in the background.""" + print(f"[3/4] Launching Twinkle server (config={config})...") + config_path = os.path.join(WORKDIR, config) + if not os.path.isfile(config_path): + print(f" ERROR: config not found: {config_path}", file=sys.stderr) + sys.exit(1) + + cmd = f"{PYTHON} -m twinkle.server launch --config {config_path}" + log_fd = open(SERVER_LOG, "w") + subprocess.Popen( + cmd, shell=True, cwd=WORKDIR, stdout=log_fd, stderr=log_fd, + env={**os.environ, "TWINKLE_TRUST_REMOTE_CODE": "1"}, + start_new_session=True, + ) + print(f" Server starting (log: {SERVER_LOG})") + + +def wait_ready(): + """Poll until the server is fully ready.""" + print(f"[4/4] Waiting for server to be ready (timeout={TIMEOUT}s)...") + start = time.time() + while time.time() - start < TIMEOUT: + try: + resp = requests.get(SERVER_URL, timeout=3) + if resp.ok and READY_KEYWORD in resp.text: + elapsed = time.time() - start + print(f" Server READY ({elapsed:.0f}s)") + return True + except (requests.ConnectionError, requests.Timeout): + pass + time.sleep(POLL_INTERVAL) + + print(f" TIMEOUT: server not ready after {TIMEOUT}s", file=sys.stderr) + print(f" Check log: tail -50 {SERVER_LOG}", file=sys.stderr) + return False + + +def main(): + parser = argparse.ArgumentParser(description="Restart Ray + launch Twinkle server") + parser.add_argument("--config", default=DEFAULT_CONFIG, + help=f"Server config yaml (default: {DEFAULT_CONFIG})") + parser.add_argument("--kill-only", action="store_true", + help="Only kill server + Ray, don't restart") + parser.add_argument("--no-ray", action="store_true", + help="Skip Ray restart (server only)") + args = parser.parse_args() + + kill_server() + + if args.kill_only: + run(f"{RAY} stop --force", check=False) + print("Done (kill-only)") + return 0 + + if not args.no_ray: + restart_ray() + + launch_server(args.config) + + if wait_ready(): + print("\n✓ All set. Run your test:") + print(f" TWINKLE_TEST_GPU_E2E=1 {PYTHON} -u tests/server/integration/test_full_cycle_e2e.py") + return 0 + else: + return 1 + + +if __name__ == "__main__": + sys.exit(main()) From defeec2aa278c4f6d6c1e88059c24ff9c7e622e2 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Thu, 25 Jun 2026 15:40:35 +0800 Subject: [PATCH 06/16] feat: add deep health check to detect model actor death - Add ping() method to all model backends (mock/transformers/megatron) - Add check_model_health() to ModelManagement that calls ping() - Add /healthz endpoint on model deployment (returns 503 if actors dead) - Add /twinkle/healthz/deep on gateway that probes all model deployments - Update entrypoint.sh watchdog to check deep health URL - Update validation middleware to skip sticky session for /healthz/* paths - Fix start_e2e_server.py DEFAULT_CONFIG path --- cookbook/client/server/megatron/entrypoint.sh | 14 +++++-- .../server/gateway/twinkle_handlers.py | 39 +++++++++++++++++++ src/twinkle/server/model/app.py | 15 +++++++ .../server/model/backends/megatron_model.py | 10 +++++ .../server/model/backends/mock_model.py | 5 +++ .../model/backends/transformers_model.py | 5 +++ src/twinkle/server/model/twinkle_handlers.py | 12 ++++++ src/twinkle/server/utils/validation.py | 2 +- tests/server/start_e2e_server.py | 2 +- 9 files changed, 99 insertions(+), 5 deletions(-) diff --git a/cookbook/client/server/megatron/entrypoint.sh b/cookbook/client/server/megatron/entrypoint.sh index 34a2bfa5e..6f55d4e68 100755 --- a/cookbook/client/server/megatron/entrypoint.sh +++ b/cookbook/client/server/megatron/entrypoint.sh @@ -11,10 +11,12 @@ TWINKLE_WORK_DIR="${TWINKLE_WORK_DIR:-/dashscope/caches/application/twinkle}" TEMP_DIR="${TWINKLE_TEMP_DIR:-/dashscope/caches/application/ray_logs}" LOG_FILE="$TWINKLE_WORK_DIR/run.log" TWINKLE_HEALTH_URL="${TWINKLE_HEALTH_URL:-http://127.0.0.1:9000/api/v1/healthz}" +TWINKLE_DEEP_HEALTH_URL="${TWINKLE_DEEP_HEALTH_URL:-http://127.0.0.1:9000/api/v1/twinkle/healthz/deep}" TWINKLE_WATCHDOG_INTERVAL_SECONDS="${TWINKLE_WATCHDOG_INTERVAL_SECONDS:-10}" TWINKLE_WATCHDOG_FAILURE_THRESHOLD="${TWINKLE_WATCHDOG_FAILURE_THRESHOLD:-3}" TWINKLE_RAY_GRACE_SECONDS="${TWINKLE_RAY_GRACE_SECONDS:-30}" TWINKLE_HEALTH_GRACE_SECONDS="${TWINKLE_HEALTH_GRACE_SECONDS:-${TWINKLE_WATCHDOG_STARTUP_GRACE_SECONDS:-300}}" +TWINKLE_DEEP_HEALTH_GRACE_SECONDS="${TWINKLE_DEEP_HEALTH_GRACE_SECONDS:-${TWINKLE_HEALTH_GRACE_SECONDS:-300}}" RESTART_BACKOFF_SECONDS="${TWINKLE_ENTRYPOINT_RESTART_BACKOFF_SECONDS:-10}" CHILD_PID="" @@ -58,6 +60,7 @@ validate_entrypoint_config() { require_positive_int "TWINKLE_WATCHDOG_FAILURE_THRESHOLD" "$TWINKLE_WATCHDOG_FAILURE_THRESHOLD" require_non_negative_int "TWINKLE_RAY_GRACE_SECONDS" "$TWINKLE_RAY_GRACE_SECONDS" require_non_negative_int "TWINKLE_HEALTH_GRACE_SECONDS" "$TWINKLE_HEALTH_GRACE_SECONDS" + require_non_negative_int "TWINKLE_DEEP_HEALTH_GRACE_SECONDS" "$TWINKLE_DEEP_HEALTH_GRACE_SECONDS" require_non_negative_int "TWINKLE_ENTRYPOINT_RESTART_BACKOFF_SECONDS" "$RESTART_BACKOFF_SECONDS" require_command timeout @@ -72,13 +75,14 @@ validate_entrypoint_config() { } check_http_health() { + local url="${1:-$TWINKLE_HEALTH_URL}" if command -v curl &> /dev/null; then - curl -fsS --max-time 10 "$TWINKLE_HEALTH_URL" >/dev/null + curl -fsS --max-time 10 "$url" >/dev/null return fi if command -v wget &> /dev/null; then - wget -q --spider --timeout=10 "$TWINKLE_HEALTH_URL" + wget -q --spider --timeout=10 "$url" return fi @@ -86,7 +90,7 @@ check_http_health() { if ! command -v "$python_bin" &> /dev/null; then python_bin="python" fi - "$python_bin" - "$TWINKLE_HEALTH_URL" <<'PY' + "$python_bin" - "$url" <<'PY' import sys import urllib.request @@ -102,6 +106,7 @@ PY print_watchdog_diagnostics() { print_warning "EntryPoint watchdog 诊断信息:" echo " - health url: $TWINKLE_HEALTH_URL" + echo " - deep health url: $TWINKLE_DEEP_HEALTH_URL" echo " - run.sh pid: ${CHILD_PID:-unset}" echo " - Ray logs: $TEMP_DIR/session_latest/logs" @@ -161,6 +166,9 @@ while true; do elif ! check_http_health; then WATCHDOG_FAILURE_REASON="http health check failed: $TWINKLE_HEALTH_URL" WATCHDOG_GRACE_SECONDS="$TWINKLE_HEALTH_GRACE_SECONDS" + elif ! check_http_health "$TWINKLE_DEEP_HEALTH_URL"; then + WATCHDOG_FAILURE_REASON="deep health check failed (model actors may be dead): $TWINKLE_DEEP_HEALTH_URL" + WATCHDOG_GRACE_SECONDS="$TWINKLE_DEEP_HEALTH_GRACE_SECONDS" fi if [ -z "$WATCHDOG_FAILURE_REASON" ]; then diff --git a/src/twinkle/server/gateway/twinkle_handlers.py b/src/twinkle/server/gateway/twinkle_handlers.py index 77cc67c62..107c00c53 100644 --- a/src/twinkle/server/gateway/twinkle_handlers.py +++ b/src/twinkle/server/gateway/twinkle_handlers.py @@ -36,6 +36,45 @@ async def get_capacity_info( async def healthz(request: Request) -> types.HealthResponse: return types.HealthResponse(status='ok') + @app.get('/twinkle/healthz/deep') + async def healthz_deep( + request: Request, + self: GatewayServer = Depends(self_fn), + ) -> dict: + """Deep health check: verifies model actors are alive, not just the gateway. + + Returns 503 if any model deployment's actors are unreachable (e.g. OOM/SIGSEGV). + The entrypoint watchdog should poll this endpoint to detect silent failures. + """ + from fastapi.responses import JSONResponse + + results = {} + all_healthy = True + + for model in self.supported_models: + model_name = model.model_name + try: + resp = await self.proxy.proxy_request( + request, 'healthz', model_name, 'model') + healthy = (resp.status_code == 200) + if not healthy: + all_healthy = False + results[model_name] = { + 'healthy': healthy, + 'status_code': resp.status_code, + } + except Exception as e: + all_healthy = False + results[model_name] = { + 'healthy': False, + 'detail': str(e), + } + + body = {'healthy': all_healthy, 'models': results} + if not all_healthy: + return JSONResponse(status_code=503, content=body) + return body + @app.get('/twinkle/get_server_capabilities', response_model=types.GetServerCapabilitiesResponse) async def get_server_capabilities( request: Request, diff --git a/src/twinkle/server/model/app.py b/src/twinkle/server/model/app.py index 261e34de9..bae7b2a93 100644 --- a/src/twinkle/server/model/app.py +++ b/src/twinkle/server/model/app.py @@ -147,6 +147,21 @@ async def shutdown(self) -> None: except Exception: pass + def check_model_health(self) -> dict: + """Probe model actors liveness via a lightweight ping. + + Returns a dict with 'healthy' (bool) and 'detail' (str). + If the model actors are dead (e.g. OOM/SIGSEGV), the ping call + will raise RayActorError, signalling the watchdog to restart. + """ + try: + result = self.model.ping() + if result is True: + return {'healthy': True, 'detail': 'model actors alive'} + return {'healthy': False, 'detail': f'unexpected ping result: {result}'} + except Exception as e: + return {'healthy': False, 'detail': f'model actor unreachable: {e}'} + async def _cleanup_adapter(self, adapter_name: str) -> None: if self.get_resource_info(adapter_name): self.clear_resource_state(adapter_name) diff --git a/src/twinkle/server/model/backends/megatron_model.py b/src/twinkle/server/model/backends/megatron_model.py index 35584508d..635d981ba 100644 --- a/src/twinkle/server/model/backends/megatron_model.py +++ b/src/twinkle/server/model/backends/megatron_model.py @@ -13,6 +13,7 @@ from twinkle.server.common.datum import datum_to_input_feature, extract_rl_features_for_loss from twinkle.server.model.backends.common import (TwinkleCompatModelBase, clean_metrics, collect_forward_backward_results, to_cpu_safe_output) +from twinkle.utils.nccl_safe import nccl_safe_megatron @remote_class(execute='all') @@ -23,6 +24,7 @@ class TwinkleCompatMegatronModel(MultiLoraMegatronModel, TwinkleCompatModelBase) """ @remote_function(dispatch='slice_dp', collect=collect_forward_backward_results, sync=True) + @nccl_safe_megatron(tinker=True) def tinker_forward_backward(self, *, inputs: list[types.Datum], adapter_name: str, loss_fn: str, **kwargs): """Combined forward and backward pass.""" self._tinker_setup_loss(loss_fn, inputs, adapter_name, kwargs) @@ -43,6 +45,7 @@ def tinker_forward_backward(self, *, inputs: list[types.Datum], adapter_name: st return [results, loss] @remote_function(dispatch='slice_dp', collect=collect_forward_backward_results) + @nccl_safe_megatron(tinker=True) def tinker_forward_only(self, *, inputs: list[types.Datum], adapter_name: str = None, **kwargs): """Forward pass without gradient computation.""" template = self.get_template(adapter_name) @@ -97,13 +100,20 @@ def tinker_load(self, checkpoint_dir: str, **kwargs): # ------------------------------------------------------------------ @remote_function(dispatch='slice_dp', collect=collect_tensor_dict) + @nccl_safe_megatron(forward_only=True) def forward_only(self, *, inputs: InputFeature | list[InputFeature] | Trajectory | list[Trajectory], **kwargs): """Forward-only for twinkle-native clients (InputFeature/Trajectory I/O).""" output = super().forward_only(inputs=inputs, **kwargs) return to_cpu_safe_output(output) @remote_function(dispatch='slice_dp', collect=collect_tensor_dict) + @nccl_safe_megatron def forward_backward(self, *, inputs: InputFeature | list[InputFeature] | Trajectory | list[Trajectory], **kwargs): """Forward+backward for twinkle-native clients (InputFeature/Trajectory I/O).""" output = super().forward_backward(inputs=inputs, **kwargs) return to_cpu_safe_output(output) + + @remote_function(collect='first', lazy_collect=False) + def ping(self) -> bool: + """Lightweight liveness probe for watchdog health checks.""" + return True diff --git a/src/twinkle/server/model/backends/mock_model.py b/src/twinkle/server/model/backends/mock_model.py index ac89935c0..bc0fc06f2 100644 --- a/src/twinkle/server/model/backends/mock_model.py +++ b/src/twinkle/server/model/backends/mock_model.py @@ -235,6 +235,11 @@ def remove_adapter(self, adapter_name: str) -> None: def has_adapter(self, adapter_name: str) -> bool: return adapter_name in self._adapters + @remote_function(collect='first', lazy_collect=False) + def ping(self) -> bool: + """Lightweight liveness probe for watchdog health checks.""" + return True + def _to_tinker_loss_outputs(records: list[dict[str, Any]]) -> list[dict[str, Any]]: """Wrap mock's plain numpy-derived dicts into ``tinker.TensorData`` instances. diff --git a/src/twinkle/server/model/backends/transformers_model.py b/src/twinkle/server/model/backends/transformers_model.py index 5b90c202c..e7677b619 100644 --- a/src/twinkle/server/model/backends/transformers_model.py +++ b/src/twinkle/server/model/backends/transformers_model.py @@ -105,3 +105,8 @@ def forward_backward(self, *, inputs: InputFeature | list[InputFeature] | Trajec """Forward+backward for twinkle-native clients (InputFeature/Trajectory I/O).""" output = super().forward_backward(inputs=inputs, **kwargs) return to_cpu_safe_output(output) + + @remote_function(collect='first', lazy_collect=False) + def ping(self) -> bool: + """Lightweight liveness probe for watchdog health checks.""" + return True diff --git a/src/twinkle/server/model/twinkle_handlers.py b/src/twinkle/server/model/twinkle_handlers.py index 5f1fc02d8..2074f5e40 100644 --- a/src/twinkle/server/model/twinkle_handlers.py +++ b/src/twinkle/server/model/twinkle_handlers.py @@ -62,6 +62,18 @@ def _register_twinkle_routes(app: FastAPI, self_fn: Callable[[], ModelManagement replica instance. It is wired in via Depends so it is resolved lazily at request time. """ + @app.get('/healthz') + async def model_healthz( + request: Request, + self: ModelManagement = Depends(self_fn), + ) -> dict: + """Deep health probe: pings underlying model actors to verify liveness.""" + result = self.check_model_health() + if not result['healthy']: + from fastapi.responses import JSONResponse + return JSONResponse(status_code=503, content=result) + return result + async def run_task(coro): """Await a schedule_task_and_wait coroutine and surface any exception as a structured HTTP 500 response so the client receives the full traceback instead diff --git a/src/twinkle/server/utils/validation.py b/src/twinkle/server/utils/validation.py index b222eb9be..ef779cd27 100644 --- a/src/twinkle/server/utils/validation.py +++ b/src/twinkle/server/utils/validation.py @@ -32,7 +32,7 @@ async def verify_request_token(request: Request, call_next): return JSONResponse(status_code=403, content={'detail': 'Invalid token'}) path = request.url.path - skip_sticky = (path.endswith('/healthz') or any(path.endswith(s) for s in _OPENAI_COMPAT_SUFFIXES)) + skip_sticky = ('/healthz' in path or any(path.endswith(s) for s in _OPENAI_COMPAT_SUFFIXES)) if not skip_sticky: request_id = request.headers.get(H_REQUEST_ID) diff --git a/tests/server/start_e2e_server.py b/tests/server/start_e2e_server.py index f9af1d9ab..dcf6dd3f4 100644 --- a/tests/server/start_e2e_server.py +++ b/tests/server/start_e2e_server.py @@ -18,7 +18,7 @@ RAY = "/mnt/nas2/anaconda3/envs/tinker_myl/bin/ray" PYTHON = "/mnt/nas2/anaconda3/envs/tinker_myl/bin/python" WORKDIR = "/mnt/nas2/yunlin.myl/twinkle" -DEFAULT_CONFIG = "server_config_4b_e2e.yaml" +DEFAULT_CONFIG = "tests/server/config/server_config_4b_e2e.yaml" RAY_TEMP_DIR = "/mnt/nas2/yunlin.myl/ray_logs" SERVER_LOG = os.path.join(WORKDIR, "server_e2e.log") From e6fd6658231f3e6e0499163b6620ca148a5b0372 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Thu, 25 Jun 2026 16:43:14 +0800 Subject: [PATCH 07/16] update nccl safe --- .../model/transformers/transformers.py | 2 - src/twinkle/utils/nccl_safe.py | 77 ++++- .../server/config/server_config_4b_e2e.yaml | 2 +- .../config/server_config_4b_e2e_megatron.yaml | 2 +- ...rver_config_4b_e2e_megatron_failfast1.yaml | 124 +++++++ .../test_dpo_nccl_safe_megatron.py | 184 +++++++++++ .../server/integration/test_race_nccl_hang.py | 251 ++++++++++++++ tests/server/test_nccl_safe.py | 308 ++++++++++++++++++ 8 files changed, 945 insertions(+), 5 deletions(-) rename server_config_4b_e2e.yaml => tests/server/config/server_config_4b_e2e.yaml (98%) rename server_config_4b_e2e_megatron.yaml => tests/server/config/server_config_4b_e2e_megatron.yaml (98%) create mode 100644 tests/server/config/server_config_4b_e2e_megatron_failfast1.yaml create mode 100644 tests/server/integration/test_dpo_nccl_safe_megatron.py create mode 100644 tests/server/integration/test_race_nccl_hang.py diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 6abcaf3cc..61733d7dc 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -42,7 +42,6 @@ from twinkle.utils import construct_class, get_logger, selective_log_softmax, torch_util from twinkle.utils.framework import Torch from twinkle.utils.grad_clip import normalize_and_clip_grad_norm -from twinkle.utils.nccl_safe import nccl_safe from twinkle.utils.transformers_utils import filter_from_config_kwargs logger = get_logger() @@ -622,7 +621,6 @@ def backward(self, **kwargs): optimizer_config.train_status.loss_value = None @remote_function(dispatch='slice_dp', collect=collect_tensor_dict) - @nccl_safe def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], **kwargs): """Do forward, calculate loss, and backward. diff --git a/src/twinkle/utils/nccl_safe.py b/src/twinkle/utils/nccl_safe.py index 49812fd54..5311ee793 100644 --- a/src/twinkle/utils/nccl_safe.py +++ b/src/twinkle/utils/nccl_safe.py @@ -1,7 +1,7 @@ # Copyright (c) ModelScope Contributors. All rights reserved. """NCCL-safe utilities for production distributed training. -Provides two layers of protection to prevent NCCL hangs: +Provides three layers of protection to prevent NCCL hangs: Layer 1 - safe_loss(): Wraps loss instances to catch computation errors and return @@ -11,6 +11,12 @@ Wraps forward_backward methods to ensure backward() always executes after forward() has started, even if intermediate code raises. +Layer 3 - @nccl_safe_megatron decorator: + Wraps Megatron backend methods (forward_only, forward_backward) where + the entire function body involves NCCL communication (sync=True). + Catches pre-communication errors (e.g. data preprocessing failures) + that would otherwise leave other DP ranks waiting at a collective. + Controlled by environment variable: TWINKLE_FAIL_FAST=1 (default, development): all protection is transparent, exceptions propagate normally. @@ -251,3 +257,72 @@ def _force_zero_backward(model, og, adapter_name, kwargs): if gas is not None: bwd_kwargs['gradient_accumulation_steps'] = gas model.backward(**bwd_kwargs) + + +# ─── Layer 3: @nccl_safe_megatron decorator ────────────────────────────────── + + +def nccl_safe_megatron(func=None, *, tinker=False, forward_only=False): + """Decorator for Megatron backend methods where the entire body is NCCL-critical. + + Unlike @nccl_safe (which detects forward/backward boundaries), this decorator + treats the **entire function** as a NCCL-critical section. In Megatron, + forward_only and forward_backward both call get_forward_backward_func() which + requires all DP ranks to enter synchronously. If one rank fails during data + preprocessing (before entering Megatron's scheduler), other ranks will hang + waiting for the collective. + + This decorator catches ALL exceptions (when TWINKLE_FAIL_FAST=0) and returns + a safe fallback value, preventing NCCL hang from asymmetric failures. + + Args: + func: The function to decorate (when used without arguments). + tinker: If True, fallback returns ``[[], 0.0]`` (tinker format). + forward_only: If True, fallback returns empty dict ``{}`` (forward_only format). + + Usage:: + + @remote_function(dispatch='slice_dp', collect=..., sync=True) + @nccl_safe_megatron + def forward_backward(self, *, inputs, **kwargs): + ... + + @remote_function(dispatch='slice_dp', collect=...) + @nccl_safe_megatron(forward_only=True) + def forward_only(self, *, inputs, **kwargs): + ... + + @remote_function(dispatch='slice_dp', collect=..., sync=True) + @nccl_safe_megatron(tinker=True) + def tinker_forward_backward(self, *, inputs, **kwargs): + ... + """ + + def decorator(fn): + + @functools.wraps(fn) + def wrapper(self, *args, **kwargs): + if _is_fail_fast(): + return fn(self, *args, **kwargs) + + try: + return fn(self, *args, **kwargs) + except Exception as e: + logger.warning(f'[nccl_safe_megatron] Exception in Megatron method ' + f'{fn.__name__}: {type(e).__name__}: {e}') + + # Return safe fallback to prevent NCCL hang on other ranks + if tinker: + return [[], 0.0] + if forward_only: + return {} + # forward_backward fallback: return dict with loss=0.0 + return {'loss': 0.0} + + return wrapper + + if func is not None: + # @nccl_safe_megatron without arguments + return decorator(func) + # @nccl_safe_megatron(tinker=True) with arguments + return decorator diff --git a/server_config_4b_e2e.yaml b/tests/server/config/server_config_4b_e2e.yaml similarity index 98% rename from server_config_4b_e2e.yaml rename to tests/server/config/server_config_4b_e2e.yaml index 57a30476f..6a9dfd317 100644 --- a/server_config_4b_e2e.yaml +++ b/tests/server/config/server_config_4b_e2e.yaml @@ -13,7 +13,7 @@ applications: import_path: server args: server_config: - per_token_model_limit: 3 + per_token_model_limit: 30 supported_models: - Qwen/Qwen3.5-4B deployments: diff --git a/server_config_4b_e2e_megatron.yaml b/tests/server/config/server_config_4b_e2e_megatron.yaml similarity index 98% rename from server_config_4b_e2e_megatron.yaml rename to tests/server/config/server_config_4b_e2e_megatron.yaml index 4a458fc0f..10a6b5ce3 100644 --- a/server_config_4b_e2e_megatron.yaml +++ b/tests/server/config/server_config_4b_e2e_megatron.yaml @@ -13,7 +13,7 @@ applications: import_path: server args: server_config: - per_token_model_limit: 3 + per_token_model_limit: 30 supported_models: - Qwen/Qwen3.5-4B deployments: diff --git a/tests/server/config/server_config_4b_e2e_megatron_failfast1.yaml b/tests/server/config/server_config_4b_e2e_megatron_failfast1.yaml new file mode 100644 index 000000000..d8103b938 --- /dev/null +++ b/tests/server/config/server_config_4b_e2e_megatron_failfast1.yaml @@ -0,0 +1,124 @@ +# Twinkle Server Configuration - E2E Test (4B model, Megatron backend) + +proxy_location: EveryNode + +http_options: + host: 0.0.0.0 + port: 9000 + +applications: + + - name: server + route_prefix: /api/v1 + import_path: server + args: + server_config: + per_token_model_limit: 20 + supported_models: + - Qwen/Qwen3.5-4B + deployments: + - name: TinkerCompatServer + max_ongoing_requests: 50 + autoscaling_config: + min_replicas: 1 + max_replicas: 1 + target_ongoing_requests: 128 + ray_actor_options: + num_cpus: 0.1 + runtime_env: + env_vars: + TWINKLE_FAIL_FAST: "1" + + - name: models-Qwen3.5-4B + route_prefix: /api/v1/model/Qwen/Qwen3.5-4B + import_path: model + args: + backend: megatron + model_id: "ms://Qwen/Qwen3.5-4B" + max_length: 10240 + nproc_per_node: 2 + device_group: + name: model + ranks: 2 + device_type: cuda + device_mesh: + device_type: cuda + dp_size: 2 + queue_config: + rps_limit: 100 + tps_limit: 100000 + adapter_config: + adapter_timeout: 30 + adapter_max_lifetime: 36000 + deployments: + - name: ModelManagement + autoscaling_config: + min_replicas: 1 + max_replicas: 1 + target_ongoing_requests: 16 + ray_actor_options: + num_cpus: 0.1 + runtime_env: + env_vars: + TWINKLE_TRUST_REMOTE_CODE: "1" + TWINKLE_FAIL_FAST: "1" + + - name: sampler-Qwen3.5-4B + route_prefix: /api/v1/sampler/Qwen/Qwen3.5-4B + import_path: sampler + args: + model_id: "ms://Qwen/Qwen3.5-4B" + nproc_per_node: 1 + sampler_type: vllm + engine_args: + max_model_len: 4096 + gpu_memory_utilization: 0.5 + enable_lora: true + max_loras: 5 + logprobs_mode: processed_logprobs + device_group: + name: sampler + ranks: 1 + device_type: cuda + device_mesh: + device_type: cuda + dp_size: 1 + queue_config: + rps_limit: 100 + tps_limit: 100000 + deployments: + - name: SamplerManagement + autoscaling_config: + min_replicas: 1 + max_replicas: 1 + target_ongoing_requests: 16 + ray_actor_options: + num_cpus: 0.1 + runtime_env: + env_vars: + TWINKLE_TRUST_REMOTE_CODE: "1" + TWINKLE_FAIL_FAST: "1" + + - name: processor + route_prefix: /api/v1/processor + import_path: processor + args: + ncpu_proc_per_node: 2 + device_group: + name: model + ranks: 2 + device_type: CPU + device_mesh: + device_type: CPU + dp_size: 2 + deployments: + - name: ProcessorManagement + autoscaling_config: + min_replicas: 1 + max_replicas: 1 + target_ongoing_requests: 128 + ray_actor_options: + num_cpus: 0.1 + runtime_env: + env_vars: + TWINKLE_FAIL_FAST: "1" diff --git a/tests/server/integration/test_dpo_nccl_safe_megatron.py b/tests/server/integration/test_dpo_nccl_safe_megatron.py new file mode 100644 index 000000000..96831bf78 --- /dev/null +++ b/tests/server/integration/test_dpo_nccl_safe_megatron.py @@ -0,0 +1,184 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""DPO NCCL-safe verification test for Megatron backend. + +Tests that DPO training (forward_only + forward_backward) does NOT cause +NCCL hang when errors occur, verifying the nccl_safe_megatron fix. + +Prerequisites: + 1. Ray cluster running with GPUs + 2. Twinkle server started with Megatron backend and TWINKLE_FAIL_FAST=0 + +Usage (direct): + python tests/server/integration/test_dpo_nccl_safe_megatron.py +""" +from __future__ import annotations + +import os +import sys +import time + +import numpy as np + +SERVER_URL = os.environ.get('TWINKLE_SERVER_URL', 'http://localhost:9000') +BASE_MODEL = 'Qwen/Qwen3.5-4B' +TIMEOUT = 120 + + +def log(msg): + print(f'[DPO-Megatron] {msg}', flush=True) + + +def wait_for_server(url, timeout=300): + import requests + start = time.time() + while time.time() - start < timeout: + try: + resp = requests.get(f'{url}/-/routes', timeout=5) + if resp.status_code == 200: + log(f'Server ready ({int(time.time() - start)}s)') + return True + except Exception: + pass + time.sleep(5) + raise TimeoutError(f'Server not ready after {timeout}s') + + +def init_dpo_client(): + """Initialize Twinkle client for DPO training.""" + from twinkle_client import init_twinkle_client + from twinkle_client.model import MultiLoraTransformersModel + from peft import LoraConfig + + init_twinkle_client(base_url=SERVER_URL, api_key='EMPTY_TOKEN') + + model = MultiLoraTransformersModel(model_id=f'ms://{BASE_MODEL}') + model.add_adapter_to_model( + adapter_name='dpo-test', + config=LoraConfig(r=16, target_modules=['q_proj', 'v_proj']), + gradient_accumulation_steps=1, + ) + model.set_loss('DPOLoss', init_args={'beta': 0.1}) + model.set_optimizer('Adam', lr=1e-5) + model.set_template('Qwen3_5Template') + model.set_processor('InputProcessor', padding_side='right') + log('DPO client configured') + return model + + +def make_dpo_batch(batch_size=4, seq_len=64, completion_len=32): + """Create DPO batch: interleaved chosen/rejected pairs.""" + prompt_len = seq_len - completion_len + all_inputs = [] + + # DPO requires even batch (chosen/rejected pairs) + for i in range(batch_size): + input_ids = list(range(1, seq_len + 1)) + labels = [-100] * prompt_len + list(range(100, 100 + completion_len)) + all_inputs.append({ + 'input_ids': input_ids, + 'labels': labels, + 'attention_mask': [1] * seq_len, + }) + + return all_inputs + + +def run_dpo_step(model, test_name, *, bad_ref_logps=False): + """Execute a full DPO step: forward_only (ref) + forward_backward (policy).""" + batch = make_dpo_batch(batch_size=4) + + # Step 1: forward_only for reference logps + log(f'[{test_name}] forward_only (reference)...') + start = time.time() + try: + ref_result = model.forward_only(inputs=batch, disable_lora=True) + elapsed = time.time() - start + log(f'[{test_name}] forward_only OK ({elapsed:.1f}s)') + except Exception as e: + elapsed = time.time() - start + log(f'[{test_name}] forward_only FAILED ({elapsed:.1f}s): {e}') + if elapsed > TIMEOUT: + log(f'[{test_name}] TIMEOUT! NCCL HANG detected!') + return False + return True # Error was caught, not a hang + + # Step 2: forward_backward for policy training + log(f'[{test_name}] forward_backward (policy)...') + start = time.time() + try: + # For DPO, we need ref_logps from the reference forward + kwargs = {} + if hasattr(ref_result, 'result') and ref_result.result: + # Extract ref_logps from forward_only result + pass # In real DPO, client passes ref_logps + result = model.forward_backward(inputs=batch, **kwargs) + elapsed = time.time() - start + log(f'[{test_name}] forward_backward OK ({elapsed:.1f}s)') + except Exception as e: + elapsed = time.time() - start + log(f'[{test_name}] forward_backward FAILED ({elapsed:.1f}s): {e}') + if elapsed > TIMEOUT: + log(f'[{test_name}] TIMEOUT! NCCL HANG detected!') + return False + return True # Error was caught, not a hang + + # Step 3: optimizer step + try: + model.clip_grad_and_step() + log(f'[{test_name}] clip_grad_and_step OK') + except Exception as e: + log(f'[{test_name}] clip_grad_and_step FAILED: {e}') + + return True + + +def main(): + log('=' * 60) + log('DPO NCCL-Safe Verification - Megatron Backend') + log('=' * 60) + log(f'Server URL: {SERVER_URL}') + log(f'TWINKLE_FAIL_FAST = {os.getenv("TWINKLE_FAIL_FAST", "1 (default)")}') + + wait_for_server(SERVER_URL) + model = init_dpo_client() + + results = [] + + # Test 1: Normal DPO training (should work) + passed = run_dpo_step(model, 'TEST-1-NORMAL-DPO') + results.append(('TEST-1: Normal DPO', passed)) + + # Test 2: Multiple consecutive DPO steps + for i in range(3): + passed = run_dpo_step(model, f'TEST-2-CONSECUTIVE-{i+1}') + results.append((f'TEST-2-{i+1}: Consecutive DPO', passed)) + + # Test 3: forward_only then forward_backward rapidly + passed = run_dpo_step(model, 'TEST-3-RAPID') + results.append(('TEST-3: Rapid DPO', passed)) + + # Test 4: Health check after all DPO operations + log('[TEST-4] Final health check - forward_backward...') + batch = make_dpo_batch(batch_size=4) + start = time.time() + try: + model.forward_backward(inputs=batch) + elapsed = time.time() - start + log(f'[TEST-4] OK ({elapsed:.1f}s)') + results.append(('TEST-4: Final health', True)) + except Exception as e: + elapsed = time.time() - start + log(f'[TEST-4] FAILED ({elapsed:.1f}s): {e}') + results.append(('TEST-4: Final health', elapsed < TIMEOUT)) + + # Summary + log(f'\n{"=" * 60}\nRESULTS SUMMARY\n{"=" * 60}') + all_passed = all(p for _, p in results) + for name, status in results: + log(f' [{"PASS" if status else "FAIL"}] {name}') + log(f'\n{"ALL" if all_passed else "SOME"} {len(results)} TESTS {"PASSED" if all_passed else "FAILED"}!') + return 0 if all_passed else 1 + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/tests/server/integration/test_race_nccl_hang.py b/tests/server/integration/test_race_nccl_hang.py new file mode 100644 index 000000000..3072a7aae --- /dev/null +++ b/tests/server/integration/test_race_nccl_hang.py @@ -0,0 +1,251 @@ +"""Multi-client race condition stress test: reproduce ncclCommSplit hang. + +Strategy: 3 concurrent clients hammer the Megatron server in a loop: +- Client A: continuous DPO training (forward_only + forward_backward + step) +- Client B: continuous DPO training with occasional bad data (triggers errors) +- Client C: repeatedly creates/destroys adapters (triggers ncclCommSplit) + +Runs multiple rounds. If any operation takes >TIMEOUT seconds, NCCL hang is detected. + +Usage: + python tests/server/integration/test_race_nccl_hang.py + python tests/server/integration/test_race_nccl_hang.py --rounds 5 +""" +import argparse +import os +import sys +import threading +import time + +os.environ['TINKER_BASE_URL'] = 'http://localhost:9000' +os.environ['TWINKLE_SERVER_TOKEN'] = 'EMPTY_TOKEN' + +from twinkle_client import init_twinkle_client +from twinkle_client.model import MultiLoraTransformersModel +from peft import LoraConfig + +SERVER_URL = os.environ.get('TWINKLE_SERVER_URL', 'http://localhost:9000') +TIMEOUT = 90 +seq_len = 64 + + +def make_batch(size=4, include_position_ids=True): + batch = [] + for _ in range(size): + item = { + 'input_ids': list(range(1, seq_len + 1)), + 'labels': [-100] * 32 + list(range(100, 132)), + 'attention_mask': [1] * seq_len, + } + if include_position_ids: + item['position_ids'] = list(range(seq_len)) + batch.append(item) + return batch + + +results = {'hangs': [], 'errors': [], 'steps_ok': 0, 'rounds_ok': 0} +lock = threading.Lock() +stop_event = threading.Event() + + +def log(msg): + print(f'[RACE] {msg}', flush=True) + + +def create_session(name): + init_twinkle_client(base_url=SERVER_URL, api_key='EMPTY_TOKEN') + model = MultiLoraTransformersModel(model_id='ms://Qwen/Qwen3.5-4B') + model.add_adapter_to_model( + adapter_name=name, + config=LoraConfig(r=16, target_modules=['q_proj', 'v_proj']), + gradient_accumulation_steps=1, + ) + model.set_loss('DPOLoss', init_args={'beta': 0.1}) + model.set_optimizer('Adam', lr=1e-5) + model.set_template('Qwen3_5Template') + model.set_processor('InputProcessor', padding_side='right') + return model + + +def client_a_training(steps_per_round): + """Client A: continuous normal DPO training.""" + try: + model = create_session('client-a') + log('Client-A: session ready') + batch = make_batch(4) + for i in range(steps_per_round): + if stop_event.is_set(): + return + start = time.time() + model.forward_only(inputs=batch, disable_lora=True) + model.forward_backward(inputs=batch) + model.clip_grad_and_step() + elapsed = time.time() - start + with lock: + results['steps_ok'] += 1 + if i % 3 == 0: + log(f'Client-A: step {i + 1}/{steps_per_round} ({elapsed:.1f}s)') + if elapsed > TIMEOUT: + with lock: + results['hangs'].append(f'Client-A step {i + 1} ({elapsed:.0f}s)') + stop_event.set() + return + except Exception as e: + with lock: + results['errors'].append(f'Client-A: {type(e).__name__}: {str(e)[:80]}') + log(f'Client-A: ERROR {type(e).__name__}: {str(e)[:80]}') + + +def client_b_mixed_training(steps_per_round): + """Client B: training with mix of good and bad requests.""" + try: + model = create_session('client-b') + log('Client-B: session ready') + good_batch = make_batch(4) + bad_batch_no_pos = make_batch(4, include_position_ids=False) # missing position_ids + bad_batch_odd = make_batch(3) # odd size for DPO + + for i in range(steps_per_round): + if stop_event.is_set(): + return + start = time.time() + try: + # Every 4th request: send bad data to trigger error + if i % 4 == 3: + model.forward_backward(inputs=bad_batch_no_pos) + elif i % 7 == 6: + model.forward_backward(inputs=bad_batch_odd) + else: + model.forward_only(inputs=good_batch, disable_lora=True) + model.forward_backward(inputs=good_batch) + model.clip_grad_and_step() + elapsed = time.time() - start + with lock: + results['steps_ok'] += 1 + if i % 3 == 0: + log(f'Client-B: step {i + 1}/{steps_per_round} ({elapsed:.1f}s)') + except Exception: + elapsed = time.time() - start + if elapsed > TIMEOUT: + with lock: + results['hangs'].append(f'Client-B step {i + 1} ({elapsed:.0f}s)') + stop_event.set() + return + # Expected errors from bad data - continue + if i % 5 == 0: + log(f'Client-B: step {i + 1} error (expected, {elapsed:.1f}s)') + except Exception as e: + with lock: + results['errors'].append(f'Client-B: {type(e).__name__}: {str(e)[:80]}') + log(f'Client-B: ERROR {type(e).__name__}: {str(e)[:80]}') + + +def client_c_adapter_churn(count): + """Client C: repeatedly create adapters (triggers ncclCommSplit).""" + time.sleep(0.5) # let training start first + for i in range(count): + if stop_event.is_set(): + return + name = f'churn-{i}' + start = time.time() + try: + init_twinkle_client(base_url=SERVER_URL, api_key='EMPTY_TOKEN') + m = MultiLoraTransformersModel(model_id='ms://Qwen/Qwen3.5-4B') + m.add_adapter_to_model( + adapter_name=name, + config=LoraConfig(r=16, target_modules=['q_proj', 'v_proj']), + gradient_accumulation_steps=1, + ) + elapsed = time.time() - start + if i % 2 == 0: + log(f'Client-C: adapter {name} OK ({elapsed:.1f}s)') + if elapsed > TIMEOUT: + with lock: + results['hangs'].append(f'Client-C {name} ({elapsed:.0f}s)') + stop_event.set() + return + except Exception as e: + elapsed = time.time() - start + if elapsed > TIMEOUT: + with lock: + results['hangs'].append(f'Client-C {name} ({elapsed:.0f}s)') + stop_event.set() + return + log(f'Client-C: {name} error ({elapsed:.1f}s) - continuing') + time.sleep(0.1) + + +def run_round(round_num, steps_per_round=10, adapter_churn=5): + """Run one round of the stress test.""" + global results + stop_event.clear() + + log(f'--- Round {round_num} (steps={steps_per_round}, churn={adapter_churn}) ---') + + t1 = threading.Thread(target=client_a_training, args=(steps_per_round,)) + t2 = threading.Thread(target=client_b_mixed_training, args=(steps_per_round,)) + t3 = threading.Thread(target=client_c_adapter_churn, args=(adapter_churn,)) + + threads = [t1, t2, t3] + for t in threads: + t.start() + + for t in threads: + t.join(timeout=300) + + alive = [t for t in threads if t.is_alive()] + if alive: + with lock: + results['hangs'].append(f'Round {round_num}: {len(alive)} threads stuck') + return False + + if results['hangs']: + return False + + with lock: + results['rounds_ok'] += 1 + return True + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--rounds', type=int, default=3) + parser.add_argument('--steps', type=int, default=10) + parser.add_argument('--churn', type=int, default=5) + args = parser.parse_args() + + log('=' * 60) + log('Multi-client Race Condition STRESS Test') + log(f'TWINKLE_FAIL_FAST = {os.getenv("TWINKLE_FAIL_FAST", "not set")}') + log(f'Rounds={args.rounds}, Steps/round={args.steps}, Adapter churn={args.churn}') + log('=' * 60) + + t_start = time.time() + for r in range(1, args.rounds + 1): + ok = run_round(r, steps_per_round=args.steps, adapter_churn=args.churn) + if not ok: + break + + total = time.time() - t_start + log('') + log('=' * 60) + log(f'FINAL RESULTS ({total:.1f}s total)') + log('=' * 60) + log(f' Rounds OK: {results["rounds_ok"]}/{args.rounds}') + log(f' Steps OK: {results["steps_ok"]}') + log(f' Errors: {len(results["errors"])}') + for e in results['errors'][:5]: + log(f' - {e}') + log(f' Hangs: {results["hangs"]}') + + if results['hangs']: + log('') + log('*** NCCL HANG DETECTED ***') + return 1 + log('') + log('ALL ROUNDS PASSED - no hang detected.') + return 0 + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/tests/server/test_nccl_safe.py b/tests/server/test_nccl_safe.py index eaa3d6cec..1bd852a1e 100644 --- a/tests/server/test_nccl_safe.py +++ b/tests/server/test_nccl_safe.py @@ -22,6 +22,7 @@ _is_fail_fast, _zero_loss, nccl_safe, + nccl_safe_megatron, safe_loss, ) @@ -650,3 +651,310 @@ def test_og_transparent(self, _dev_mode): assert og.loss_instance._nccl_safe_wrapped is True with pytest.raises(RuntimeError, match='Simulated'): og.loss_instance({}, {'logps': torch.tensor([1.0])}) + + +# ═════════════════════════════════════════════════════════════════════════ +# 10. Unit Tests: @nccl_safe_megatron decorator +# ═════════════════════════════════════════════════════════════════════════ + + +class TestNcclSafeMegatron: + + # ── Dev mode: transparent ── + + def test_transparent_in_dev_mode(self, _dev_mode): + @nccl_safe_megatron + def method(self, *, inputs, **kwargs): + raise ValueError('should propagate') + + with pytest.raises(ValueError, match='should propagate'): + method(MagicMock(), inputs=[]) + + def test_transparent_tinker_in_dev_mode(self, _dev_mode): + @nccl_safe_megatron(tinker=True) + def method(self, *, inputs, **kwargs): + raise RuntimeError('dev error') + + with pytest.raises(RuntimeError, match='dev error'): + method(MagicMock(), inputs=[]) + + def test_transparent_forward_only_in_dev_mode(self, _dev_mode): + @nccl_safe_megatron(forward_only=True) + def method(self, *, inputs, **kwargs): + raise RuntimeError('dev error') + + with pytest.raises(RuntimeError, match='dev error'): + method(MagicMock(), inputs=[]) + + # ── Production mode: catch all exceptions ── + + def test_normal_call_passes(self): + @nccl_safe_megatron + def method(self, *, inputs, **kwargs): + return {'loss': 1.5, 'logps': 'data'} + + result = method(MagicMock(), inputs=[]) + assert result == {'loss': 1.5, 'logps': 'data'} + + def test_exception_returns_fallback_dict(self): + @nccl_safe_megatron + def method(self, *, inputs, **kwargs): + raise RuntimeError('Megatron internal error') + + result = method(MagicMock(), inputs=[]) + assert result == {'loss': 0.0} + + def test_tinker_exception_returns_list(self): + @nccl_safe_megatron(tinker=True) + def method(self, *, inputs, **kwargs): + raise ValueError('data preprocessing failed') + + result = method(MagicMock(), inputs=[]) + assert result == [[], 0.0] + + def test_forward_only_exception_returns_empty_dict(self): + @nccl_safe_megatron(forward_only=True) + def method(self, *, inputs, **kwargs): + raise AssertionError('invalid inputs') + + result = method(MagicMock(), inputs=[]) + assert result == {} + + def test_consecutive_errors_all_caught(self): + call_count = [0] + + @nccl_safe_megatron + def method(self, *, inputs, **kwargs): + call_count[0] += 1 + raise RuntimeError(f'error #{call_count[0]}') + + model = MagicMock() + for _ in range(5): + result = method(model, inputs=[]) + assert result == {'loss': 0.0} + assert call_count[0] == 5 + + def test_error_then_normal(self): + call_count = [0] + + @nccl_safe_megatron + def method(self, *, inputs, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + raise RuntimeError('first call fails') + return {'loss': 2.0} + + model = MagicMock() + r1 = method(model, inputs=[]) + assert r1 == {'loss': 0.0} + + r2 = method(model, inputs=[]) + assert r2 == {'loss': 2.0} + + def test_keyboard_interrupt_propagates(self): + """KeyboardInterrupt is BaseException, NOT caught.""" + @nccl_safe_megatron + def method(self, *, inputs, **kwargs): + raise KeyboardInterrupt() + + with pytest.raises(KeyboardInterrupt): + method(MagicMock(), inputs=[]) + + def test_system_exit_propagates(self): + @nccl_safe_megatron + def method(self, *, inputs, **kwargs): + raise SystemExit(1) + + with pytest.raises(SystemExit): + method(MagicMock(), inputs=[]) + + # ── DPO-specific scenarios ── + + def test_dpo_forward_only_data_error_caught(self): + """DPO reference forward: data preprocessing error is caught.""" + @nccl_safe_megatron(forward_only=True) + def forward_only(self, *, inputs, **kwargs): + # Simulate template.batch_encode failure + raise AssertionError('Use set_template to add a template') + + result = forward_only(MagicMock(), inputs=[]) + assert result == {} + + def test_dpo_forward_backward_assertion_caught(self): + """DPO training forward_backward: assertion error is caught.""" + @nccl_safe_megatron + def forward_backward(self, *, inputs, **kwargs): + # Simulate batch_size assertion failure + raise AssertionError('Batch size must be even (chosen + rejected pairs)') + + result = forward_backward(MagicMock(), inputs=[]) + assert result == {'loss': 0.0} + + def test_dpo_tinker_ref_logps_mismatch_caught(self): + """DPO via Tinker: ref_logps mismatch is caught.""" + @nccl_safe_megatron(tinker=True) + def tinker_forward_backward(self, *, inputs, **kwargs): + raise ValueError('Cannot align ref_logps shape') + + result = tinker_forward_backward(MagicMock(), inputs=[]) + assert result == [[], 0.0] + + +# ═════════════════════════════════════════════════════════════════════════ +# 11. Multi-adapter concurrent scenarios +# ═════════════════════════════════════════════════════════════════════════ + + +class TestMultiAdapterConcurrency: + """Verify nccl_safe handles multi-adapter scenarios correctly.""" + + def test_different_adapters_independent_state(self): + """Two adapters with independent state: one failing shouldn't corrupt the other.""" + @nccl_safe + def method(self, *, inputs, adapter_name, **kwargs): + og = self.optimizer_group[adapter_name] + og.train_status.outputs = {'logps': torch.tensor([1.0], requires_grad=True)} + if adapter_name == 'adapter_bad': + og.train_status.loss_value = torch.tensor(1.0) + raise RuntimeError('adapter_bad fails after forward') + og.train_status.loss_value = None # backward done + return {'loss': 0.5} + + # Setup model with two adapters + model = MagicMock() + ts_good = TrainStatus() + ts_bad = TrainStatus() + og_good = MagicMock() + og_good.train_status = ts_good + og_bad = MagicMock() + og_bad.train_status = ts_bad + model.optimizer_group = {'adapter_good': og_good, 'adapter_bad': og_bad} + model.backward = MagicMock() + model.model = MagicMock() + model.model.parameters = MagicMock( + return_value=iter([torch.randn(3, requires_grad=True)])) + + # adapter_good should succeed normally + result_good = method(model, inputs=[], adapter_name='adapter_good') + assert result_good == {'loss': 0.5} + + # adapter_bad should be caught and force backward + result_bad = method(model, inputs=[], adapter_name='adapter_bad') + assert result_bad['loss'] == 0.0 + model.backward.assert_called_once() + + def test_sequential_adapter_failures_isolated(self): + """Sequential failures on different adapters don't accumulate state.""" + call_count = [0] + + @nccl_safe + def method(self, *, inputs, adapter_name, **kwargs): + call_count[0] += 1 + og = self.optimizer_group[adapter_name] + og.train_status.outputs = {'logps': torch.tensor([float(call_count[0])], requires_grad=True)} + og.train_status.loss_value = None # backward done + raise RuntimeError(f'post-backward error #{call_count[0]}') + + model = MagicMock() + for name in ['a1', 'a2', 'a3']: + ts = TrainStatus() + og = MagicMock() + og.train_status = ts + model.optimizer_group = {name: og} + model.backward = MagicMock() + result = method(model, inputs=[], adapter_name=name) + assert result['loss'] == 0.0 + # backward should NOT be called (post-backward error) + model.backward.assert_not_called() + + def test_megatron_multi_adapter_all_caught(self): + """nccl_safe_megatron catches errors for any adapter.""" + @nccl_safe_megatron + def method(self, *, inputs, adapter_name, **kwargs): + if adapter_name == 'bad': + raise RuntimeError('bad adapter data') + return {'loss': 1.0} + + model = MagicMock() + assert method(model, inputs=[], adapter_name='good') == {'loss': 1.0} + assert method(model, inputs=[], adapter_name='bad') == {'loss': 0.0} + # After error, good adapter still works + assert method(model, inputs=[], adapter_name='good') == {'loss': 1.0} + + +# ═════════════════════════════════════════════════════════════════════════ +# 12. Megatron communication timeout simulation +# ═════════════════════════════════════════════════════════════════════════ + + +class TestMegatronTimeoutSimulation: + """Simulate Megatron internal communication failures.""" + + def test_nccl_timeout_exception_caught(self): + """NCCL timeout RuntimeError inside Megatron is caught.""" + @nccl_safe_megatron + def forward_backward(self, *, inputs, **kwargs): + # Simulate NCCL timeout + raise RuntimeError( + 'Watchdog caught collective operation timeout: ' + 'WorkNCCL(SeqNum=42, OpType=ALLREDUCE) ran for 300000 milliseconds') + + result = forward_backward(MagicMock(), inputs=[]) + assert result == {'loss': 0.0} + + def test_nccl_timeout_in_forward_only(self): + """NCCL timeout during forward_only (reference model) is caught.""" + @nccl_safe_megatron(forward_only=True) + def forward_only(self, *, inputs, **kwargs): + raise RuntimeError( + 'NCCL communicator was aborted on rank 1. Original reason: ' + 'ProcessGroupNCCL abort') + + result = forward_only(MagicMock(), inputs=[]) + assert result == {} + + def test_cuda_oom_in_megatron_caught(self): + """CUDA OOM during Megatron forward is caught.""" + @nccl_safe_megatron + def forward_backward(self, *, inputs, **kwargs): + raise RuntimeError('CUDA out of memory. Tried to allocate 2.00 GiB') + + result = forward_backward(MagicMock(), inputs=[]) + assert result == {'loss': 0.0} + + def test_recovery_after_timeout(self): + """System recovers after a simulated timeout.""" + call_count = [0] + + @nccl_safe_megatron + def forward_backward(self, *, inputs, **kwargs): + call_count[0] += 1 + if call_count[0] <= 2: + raise RuntimeError('NCCL timeout') + return {'loss': 1.5} + + model = MagicMock() + # First two calls timeout + assert forward_backward(model, inputs=[]) == {'loss': 0.0} + assert forward_backward(model, inputs=[]) == {'loss': 0.0} + # Third call succeeds + assert forward_backward(model, inputs=[]) == {'loss': 1.5} + + def test_megatron_transparent_in_fail_fast(self, _dev_mode): + """In dev mode (TWINKLE_FAIL_FAST=1), exceptions propagate normally.""" + @nccl_safe_megatron + def forward_backward(self, *, inputs, **kwargs): + raise RuntimeError('would cause NCCL hang in production') + + # In dev mode, exception propagates + with pytest.raises(RuntimeError, match='would cause NCCL hang'): + forward_backward(MagicMock(), inputs=[]) + + def test_base_exception_still_propagates_in_dev_mode(self, _dev_mode): + """BaseException (KeyboardInterrupt, SystemExit) always propagates.""" + @nccl_safe_megatron + def method(self, *, inputs, **kwargs): + raise KeyboardInterrupt() + + with pytest.raises(KeyboardInterrupt): + method(MagicMock(), inputs=[]) From 5d67de92203164836651e65d29301a097f334c84 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Thu, 25 Jun 2026 19:17:49 +0800 Subject: [PATCH 08/16] update test scripte --- .../client/server/megatron/server_config.yaml | 2 +- .../config/server_config_4b_dpo_megatron.yaml | 90 +++++++++ .../config/server_config_4b_e2e_megatron.yaml | 75 +++---- ...rver_config_4b_e2e_megatron_failfast1.yaml | 75 +++---- tests/server/integration/test_dpo_pp_e2e.py | 191 ++++++++++++++++++ tests/server/start_e2e_server.py | 18 +- 6 files changed, 367 insertions(+), 84 deletions(-) create mode 100644 tests/server/config/server_config_4b_dpo_megatron.yaml create mode 100644 tests/server/integration/test_dpo_pp_e2e.py diff --git a/cookbook/client/server/megatron/server_config.yaml b/cookbook/client/server/megatron/server_config.yaml index 1f47c2120..67ef5fdaa 100644 --- a/cookbook/client/server/megatron/server_config.yaml +++ b/cookbook/client/server/megatron/server_config.yaml @@ -107,7 +107,7 @@ applications: import_path: model args: backend: megatron # Use Megatron-LM backend - model_id: "ms://Qwen/Qwen3.6-27B" # ModelScope model identifier + model_id: "ms://Qwen/Qwen3.6-27B" # ModelScope model identifier max_length: 32768 # model max length max_loras: 3 # model max loras nproc_per_node: 4 # Number of GPU processes per node diff --git a/tests/server/config/server_config_4b_dpo_megatron.yaml b/tests/server/config/server_config_4b_dpo_megatron.yaml new file mode 100644 index 000000000..1fc274901 --- /dev/null +++ b/tests/server/config/server_config_4b_dpo_megatron.yaml @@ -0,0 +1,90 @@ +# Twinkle Server Configuration - DPO E2E Test (4B model, Megatron PP=2, no sampler) +# Minimal config for reproducing PP deadlock during DPO forward_backward. + +proxy_location: EveryNode + +http_options: + host: 0.0.0.0 + port: 9000 + +applications: + + - name: server + route_prefix: /api/v1 + import_path: server + args: + server_config: + per_token_model_limit: 30 + supported_models: + - Qwen/Qwen3.5-4B + deployments: + - name: TinkerCompatServer + max_ongoing_requests: 50 + autoscaling_config: + min_replicas: 1 + max_replicas: 1 + target_ongoing_requests: 128 + ray_actor_options: + num_cpus: 0.1 + runtime_env: + env_vars: + TWINKLE_FAIL_FAST: "0" + + - name: models-Qwen3.5-4B + route_prefix: /api/v1/model/Qwen/Qwen3.5-4B + import_path: model + args: + backend: megatron + model_id: "ms://Qwen/Qwen3.5-4B" + max_length: 10240 + nproc_per_node: 4 + device_group: + name: model + ranks: 4 + device_type: cuda + device_mesh: + device_type: cuda + dp_size: 2 + pp_size: 2 + queue_config: + rps_limit: 100 + tps_limit: 100000 + adapter_config: + adapter_timeout: 30 + adapter_max_lifetime: 36000 + deployments: + - name: ModelManagement + autoscaling_config: + min_replicas: 1 + max_replicas: 1 + target_ongoing_requests: 16 + ray_actor_options: + num_cpus: 0.1 + runtime_env: + env_vars: + TWINKLE_TRUST_REMOTE_CODE: "1" + TWINKLE_FAIL_FAST: "0" + + - name: processor + route_prefix: /api/v1/processor + import_path: processor + args: + ncpu_proc_per_node: 2 + device_group: + name: model + ranks: 2 + device_type: CPU + device_mesh: + device_type: CPU + dp_size: 2 + deployments: + - name: ProcessorManagement + autoscaling_config: + min_replicas: 1 + max_replicas: 1 + target_ongoing_requests: 128 + ray_actor_options: + num_cpus: 0.1 + runtime_env: + env_vars: + TWINKLE_FAIL_FAST: "0" diff --git a/tests/server/config/server_config_4b_e2e_megatron.yaml b/tests/server/config/server_config_4b_e2e_megatron.yaml index 10a6b5ce3..fc330b58e 100644 --- a/tests/server/config/server_config_4b_e2e_megatron.yaml +++ b/tests/server/config/server_config_4b_e2e_megatron.yaml @@ -36,14 +36,15 @@ applications: backend: megatron model_id: "ms://Qwen/Qwen3.5-4B" max_length: 10240 - nproc_per_node: 2 + nproc_per_node: 4 device_group: name: model - ranks: 2 + ranks: 4 device_type: cuda device_mesh: device_type: cuda dp_size: 2 + pp_size: 2 queue_config: rps_limit: 100 tps_limit: 100000 @@ -63,41 +64,41 @@ applications: TWINKLE_TRUST_REMOTE_CODE: "1" TWINKLE_FAIL_FAST: "0" - - name: sampler-Qwen3.5-4B - route_prefix: /api/v1/sampler/Qwen/Qwen3.5-4B - import_path: sampler - args: - model_id: "ms://Qwen/Qwen3.5-4B" - nproc_per_node: 1 - sampler_type: vllm - engine_args: - max_model_len: 4096 - gpu_memory_utilization: 0.5 - enable_lora: true - max_loras: 5 - logprobs_mode: processed_logprobs - device_group: - name: sampler - ranks: 1 - device_type: cuda - device_mesh: - device_type: cuda - dp_size: 1 - queue_config: - rps_limit: 100 - tps_limit: 100000 - deployments: - - name: SamplerManagement - autoscaling_config: - min_replicas: 1 - max_replicas: 1 - target_ongoing_requests: 16 - ray_actor_options: - num_cpus: 0.1 - runtime_env: - env_vars: - TWINKLE_TRUST_REMOTE_CODE: "1" - TWINKLE_FAIL_FAST: "0" + # - name: sampler-Qwen3.5-4B + # route_prefix: /api/v1/sampler/Qwen/Qwen3.5-4B + # import_path: sampler + # args: + # model_id: "ms://Qwen/Qwen3.5-4B" + # nproc_per_node: 1 + # sampler_type: vllm + # engine_args: + # max_model_len: 4096 + # gpu_memory_utilization: 0.5 + # enable_lora: true + # max_loras: 5 + # logprobs_mode: processed_logprobs + # device_group: + # name: sampler + # ranks: 1 + # device_type: cuda + # device_mesh: + # device_type: cuda + # dp_size: 1 + # queue_config: + # rps_limit: 100 + # tps_limit: 100000 + # deployments: + # - name: SamplerManagement + # autoscaling_config: + # min_replicas: 1 + # max_replicas: 1 + # target_ongoing_requests: 16 + # ray_actor_options: + # num_cpus: 0.1 + # runtime_env: + # env_vars: + # TWINKLE_TRUST_REMOTE_CODE: "1" + # TWINKLE_FAIL_FAST: "0" - name: processor route_prefix: /api/v1/processor diff --git a/tests/server/config/server_config_4b_e2e_megatron_failfast1.yaml b/tests/server/config/server_config_4b_e2e_megatron_failfast1.yaml index d8103b938..0762bcfc4 100644 --- a/tests/server/config/server_config_4b_e2e_megatron_failfast1.yaml +++ b/tests/server/config/server_config_4b_e2e_megatron_failfast1.yaml @@ -36,14 +36,15 @@ applications: backend: megatron model_id: "ms://Qwen/Qwen3.5-4B" max_length: 10240 - nproc_per_node: 2 + nproc_per_node: 4 device_group: name: model - ranks: 2 + ranks: 4 device_type: cuda device_mesh: device_type: cuda dp_size: 2 + pp_size: 2 queue_config: rps_limit: 100 tps_limit: 100000 @@ -63,41 +64,41 @@ applications: TWINKLE_TRUST_REMOTE_CODE: "1" TWINKLE_FAIL_FAST: "1" - - name: sampler-Qwen3.5-4B - route_prefix: /api/v1/sampler/Qwen/Qwen3.5-4B - import_path: sampler - args: - model_id: "ms://Qwen/Qwen3.5-4B" - nproc_per_node: 1 - sampler_type: vllm - engine_args: - max_model_len: 4096 - gpu_memory_utilization: 0.5 - enable_lora: true - max_loras: 5 - logprobs_mode: processed_logprobs - device_group: - name: sampler - ranks: 1 - device_type: cuda - device_mesh: - device_type: cuda - dp_size: 1 - queue_config: - rps_limit: 100 - tps_limit: 100000 - deployments: - - name: SamplerManagement - autoscaling_config: - min_replicas: 1 - max_replicas: 1 - target_ongoing_requests: 16 - ray_actor_options: - num_cpus: 0.1 - runtime_env: - env_vars: - TWINKLE_TRUST_REMOTE_CODE: "1" - TWINKLE_FAIL_FAST: "1" + # - name: sampler-Qwen3.5-4B + # route_prefix: /api/v1/sampler/Qwen/Qwen3.5-4B + # import_path: sampler + # args: + # model_id: "ms://Qwen/Qwen3.5-4B" + # nproc_per_node: 1 + # sampler_type: vllm + # engine_args: + # max_model_len: 4096 + # gpu_memory_utilization: 0.5 + # enable_lora: true + # max_loras: 5 + # logprobs_mode: processed_logprobs + # device_group: + # name: sampler + # ranks: 1 + # device_type: cuda + # device_mesh: + # device_type: cuda + # dp_size: 1 + # queue_config: + # rps_limit: 100 + # tps_limit: 100000 + # deployments: + # - name: SamplerManagement + # autoscaling_config: + # min_replicas: 1 + # max_replicas: 1 + # target_ongoing_requests: 16 + # ray_actor_options: + # num_cpus: 0.1 + # runtime_env: + # env_vars: + # TWINKLE_TRUST_REMOTE_CODE: "1" + # TWINKLE_FAIL_FAST: "1" - name: processor route_prefix: /api/v1/processor diff --git a/tests/server/integration/test_dpo_pp_e2e.py b/tests/server/integration/test_dpo_pp_e2e.py new file mode 100644 index 000000000..425410e5e --- /dev/null +++ b/tests/server/integration/test_dpo_pp_e2e.py @@ -0,0 +1,191 @@ +"""DPO training E2E test on Megatron PP=2 backend. + +Reproduces the PP deadlock where forward_only succeeds but forward_backward +hangs due to nccl_safe loss skip breaking pipeline P2P communication. + +Flow (mirrors cookbook/client/twinkle/modelscope/dpo.py): + Phase 1 — Setup: configure model with DPO loss + Phase 2 — forward_only (ref_outputs): base model inference, no LoRA + Phase 3 — forward_backward (DPO training): triggers PP P2P communication + +No sampler required. Server config: server_config_4b_dpo_megatron.yaml + +## How to run + + # 1. Start server (no sampler, 4 GPU model only) + python tests/server/start_e2e_server.py \\ + --config tests/server/config/server_config_4b_dpo_megatron.yaml + + # 2. Run DPO PP test + TWINKLE_TEST_GPU_E2E=1 python -u tests/server/integration/test_dpo_pp_e2e.py + +Expected: forward_only succeeds, forward_backward either succeeds or +reproduces the PP deadlock (504 timeout / NCCL hang). +""" +from __future__ import annotations + +import dotenv + +dotenv.load_dotenv('.env') + +import os # noqa: E402 +import sys # noqa: E402 +import time # noqa: E402 +from typing import Any, Dict, List # noqa: E402 + +import numpy as np # noqa: E402 +import pytest # noqa: E402 +import torch # noqa: E402 +from peft import LoraConfig # noqa: E402 + +pytestmark = pytest.mark.skipif( + os.environ.get('TWINKLE_TEST_GPU_E2E', '0') != '1', + reason='Set TWINKLE_TEST_GPU_E2E=1 to run real GPU E2E tests (requires running server)', +) + +from twinkle import get_logger, init_twinkle_client # noqa: E402 +from twinkle.dataloader import DataLoader # noqa: E402 +from twinkle.dataset import Dataset, DatasetMeta # noqa: E402 +from twinkle.preprocessor import EmojiDPOProcessor # noqa: E402 +from twinkle_client.model import MultiLoraTransformersModel # noqa: E402 + +logger = get_logger() + +# ── Configuration ── +BASE_MODEL = 'Qwen/Qwen3.5-4B' +BASE_URL = 'http://localhost:9000' +API_KEY = 'EMPTY_API_KEY' +DATASET_ID = 'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji' + +BATCH_SIZE = 4 +GRADIENT_ACCUMULATION_STEPS = 2 +DPO_BETA = 0.1 +SFT_WEIGHT = 1.0 +LOSS_TYPE = 'sigmoid' +MAX_LENGTH = 2048 +SYSTEM_PROMPT = 'You are a helpful assistant.' +DPO_TRAIN_STEPS = 4 # small number, just enough to trigger the bug +FORWARD_BACKWARD_TIMEOUT = 120 # seconds to wait before declaring hang + + +def _create_dpo_dataset() -> Dataset: + """Create DPO dataset with positive/negative format (small slice for speed).""" + dataset = Dataset(DatasetMeta(DATASET_ID, data_slice=range(100))) + dataset.set_template('Qwen3_5Template', model_id=f'ms://{BASE_MODEL}', max_length=MAX_LENGTH) + dataset.map(EmojiDPOProcessor, init_args={'system': SYSTEM_PROMPT}) + dataset.encode() + return dataset + + +def _prepare_dpo_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Interleave positive/negative pairs: [pos_1, neg_1, pos_2, neg_2, ...]. + + This DP-safe interleaving ensures each DP worker gets complete pairs + after slicing. + """ + result = [] + for row in batch: + base_fields = {k: v for k, v in row.items() if k not in ('positive', 'negative')} + pos_sample = {**base_fields, **row['positive']} + neg_sample = {**base_fields, **row['negative']} + result.append(pos_sample) + result.append(neg_sample) + return result + + +def _convert_tensors(batch: List[Dict[str, Any]]) -> None: + """Convert numpy/torch tensors to lists for serialization (in-place).""" + for row in batch: + for key in row: + if isinstance(row[key], np.ndarray): + row[key] = row[key].tolist() + elif isinstance(row[key], torch.Tensor): + row[key] = row[key].cpu().numpy().tolist() + + +def _configure_dpo_model() -> MultiLoraTransformersModel: + """Configure model with DPO loss, optimizer, and LoRA adapter.""" + model = MultiLoraTransformersModel(model_id=f'ms://{BASE_MODEL}') + model.add_adapter_to_model( + 'default', + LoraConfig(target_modules='all-linear', r=8, lora_alpha=32, lora_dropout=0.05), + gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS, + ) + model.set_template('Qwen3_5Template') + model.set_processor('InputProcessor', padding_side='right') + model.set_loss('DPOLoss', beta=DPO_BETA, loss_type=LOSS_TYPE, reference_free=False, sft_weight=SFT_WEIGHT) + model.add_metric('DPOMetric', beta=DPO_BETA) + model.set_optimizer('Adam', lr=1e-4) + return model + + +def main() -> int: + client = init_twinkle_client(base_url=BASE_URL, api_key=API_KEY) + logger.info('Available models:') + for m in client.get_server_capabilities().supported_models: + logger.info(f' - {m.model_name}') + + # ── Phase 1: Setup ── + logger.info('=' * 60) + logger.info('Phase 1: Setup — configure DPO model + load dataset') + logger.info('=' * 60) + + dataset = _create_dpo_dataset() + dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE) + model = _configure_dpo_model() + + logger.info('Model and dataset ready. Starting DPO training loop...') + logger.info(f'DPO config: beta={DPO_BETA}, loss_type={LOSS_TYPE}, sft_weight={SFT_WEIGHT}') + logger.info(f'Training {DPO_TRAIN_STEPS} steps with GA={GRADIENT_ACCUMULATION_STEPS}') + + # ── Phase 2+3: DPO training loop ── + logger.info('=' * 60) + logger.info('Phase 2+3: DPO training (forward_only + forward_backward)') + logger.info('=' * 60) + + step = 0 + for batch in dataloader: + _convert_tensors(batch) + dpo_batch = _prepare_dpo_batch(batch) + + # Phase 2: forward_only — get reference outputs (base model, no LoRA) + logger.info(f'[Step {step + 1}] forward_only (ref_outputs) ...') + t0 = time.time() + ref_outputs = model.forward_only(inputs=dpo_batch, disable_lora=True) + t_fo = time.time() - t0 + logger.info(f'[Step {step + 1}] forward_only OK ({t_fo:.1f}s)') + + # Phase 3: forward_backward — DPO training with ref_outputs + logger.info(f'[Step {step + 1}] forward_backward (DPO loss) ...') + t0 = time.time() + model.forward_backward(inputs=dpo_batch, ref_outputs=ref_outputs.result) + t_fb = time.time() - t0 + logger.info(f'[Step {step + 1}] forward_backward OK ({t_fb:.1f}s)') + + model.clip_grad_and_step() + + # Log metrics every GA steps + step += 1 + if step % GRADIENT_ACCUMULATION_STEPS == 0: + metrics = model.calculate_metric(is_training=True) + logger.info(f'[Optim step {step // GRADIENT_ACCUMULATION_STEPS}] {metrics}') + + if step >= DPO_TRAIN_STEPS: + break + + logger.info('=' * 60) + logger.info('ALL DPO PHASES PASSED') + logger.info('=' * 60) + return 0 + + +# ── pytest entry point ── + +def test_dpo_pp_e2e(): + """Pytest-collected entry point for the DPO PP E2E suite.""" + rc = main() + assert rc == 0, 'DPO PP E2E test failed' + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/tests/server/start_e2e_server.py b/tests/server/start_e2e_server.py index dcf6dd3f4..dc2a9f694 100644 --- a/tests/server/start_e2e_server.py +++ b/tests/server/start_e2e_server.py @@ -25,7 +25,7 @@ # ── Server check ── SERVER_URL = "http://localhost:9000/-/routes" READY_KEYWORD = "processor" -TIMEOUT = 180 +TIMEOUT = 600 POLL_INTERVAL = 5 @@ -62,25 +62,25 @@ def kill_server(): def restart_ray(): - """Stop and restart the Ray cluster (3 GPU + 1 CPU nodes).""" + """Stop and restart the Ray cluster (5 GPU + 1 CPU nodes).""" print("[2/4] Restarting Ray cluster...") run(f"{RAY} stop --force", check=False) time.sleep(2) - # Head node: GPU 0,1 - run(f"{RAY} start --head --port=6379 --num-gpus=2 " + # Head node: GPU 0,1,2,3 (4 GPUs for model PP=2 x DP=2) + run(f"{RAY} start --head --port=6379 --num-gpus=4 " f"--disable-usage-stats --temp-dir={RAY_TEMP_DIR}", - env={"CUDA_VISIBLE_DEVICES": "0,1"}) + env={"CUDA_VISIBLE_DEVICES": "0,1,2,3"}) - # Worker: GPU 2 + # Worker: GPU 4 (1 GPU for sampler) run(f"{RAY} start --address=127.0.0.1:6379 --num-gpus=1", - env={"CUDA_VISIBLE_DEVICES": "2"}) + env={"CUDA_VISIBLE_DEVICES": "4"}) - # CPU-only worker + # CPU-only worker (processor + server) run(f"{RAY} start --address=127.0.0.1:6379 --num-gpus=0", env={"CUDA_VISIBLE_DEVICES": ""}) - print(" Ray cluster started (2+1+0 GPUs)") + print(" Ray cluster started (4+1+0 GPUs)") def launch_server(config: str): From c21d406404256b136fb962488e038dd90fa5e97c Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Thu, 25 Jun 2026 21:07:27 +0800 Subject: [PATCH 09/16] fix: handle ragged logps in DPO forward_backward and tinker_forward_only for PP>1 When variable_seq_lengths=True with Pipeline Parallelism, logps from different DP ranks / microbatches have different seq_lens. After HTTP serialization + collect_tensor_dict, they become ragged nested lists that torch.as_tensor cannot handle. Changes: - megatron_model.py: Flatten and pad_and_stack ragged ref_outputs logps in forward_backward before passing to MegatronModel (twinkle client) - common.py: Handle ragged list[list] in _tensor_output_to_rows by flattening microbatch nesting and pad_and_stack (tinker client) - nccl_safe.py: Add traceback to safe_loss and nccl_safe_megatron error logging for better diagnostics - Add tinker client DPO PP E2E test --- src/twinkle/server/model/backends/common.py | 12 +- .../server/model/backends/megatron_model.py | 13 ++ src/twinkle/utils/nccl_safe.py | 9 +- .../integration/test_dpo_tinker_pp_e2e.py | 149 ++++++++++++++++++ 4 files changed, 179 insertions(+), 4 deletions(-) create mode 100644 tests/server/integration/test_dpo_tinker_pp_e2e.py diff --git a/src/twinkle/server/model/backends/common.py b/src/twinkle/server/model/backends/common.py index 2ac4f3eec..2e5b4a783 100644 --- a/src/twinkle/server/model/backends/common.py +++ b/src/twinkle/server/model/backends/common.py @@ -204,7 +204,17 @@ def _tensor_output_to_rows(value, seq_lens: list[int], *, kind: str) -> list[tor # Non-last PP stages can legitimately produce no logits/logps. return None elif not (isinstance(value, list) and all(isinstance(item, torch.Tensor) for item in value)): - tensors = [torch.as_tensor(value, dtype=torch.float32)] + # Handle ragged list[list] (e.g. logps after to_cpu_safe_output + # converted variable-length Tensors to nested Python lists). + # Flatten microbatch grouping → per-sample 1D lists → pad_and_stack. + if isinstance(value, list) and value and isinstance(value[0], (list, tuple)): + flat = [s for item in value for s in (item if isinstance(item[0], (list, tuple)) else [item])] + from twinkle.utils import pad_and_stack_tensors + tensors = [pad_and_stack_tensors( + [torch.tensor(s, dtype=torch.float32) for s in flat], + pad_value=0.0, concat=False)] + else: + tensors = [torch.as_tensor(value, dtype=torch.float32)] tensors = [tensor.detach().cpu() for tensor in tensors] if len(tensors) == len(seq_lens) and all(tensor.dim() <= 1 or tensor.shape[0] == 1 for tensor in tensors): diff --git a/src/twinkle/server/model/backends/megatron_model.py b/src/twinkle/server/model/backends/megatron_model.py index 635d981ba..1e73b97ae 100644 --- a/src/twinkle/server/model/backends/megatron_model.py +++ b/src/twinkle/server/model/backends/megatron_model.py @@ -110,6 +110,19 @@ def forward_only(self, *, inputs: InputFeature | list[InputFeature] | Trajectory @nccl_safe_megatron def forward_backward(self, *, inputs: InputFeature | list[InputFeature] | Trajectory | list[Trajectory], **kwargs): """Forward+backward for twinkle-native clients (InputFeature/Trajectory I/O).""" + # Normalize ragged ref_outputs logps into a regular 2D tensor. + # After HTTP + collect_tensor_dict, logps is a nested list grouped + # by microbatch with varying seq_lens across DP ranks. Flatten to + # per-sample 1D lists and pad_and_stack — same as datum.py L84-88. + ref_outputs = kwargs.get('ref_outputs') + if isinstance(ref_outputs, dict) and 'logps' in ref_outputs: + logps = ref_outputs['logps'] + if isinstance(logps, (list, tuple)) and logps and not isinstance(logps[0], torch.Tensor): + # Flatten [[mb0_sample0, mb0_sample1], [mb1_sample0, ...]] → [sample0, sample1, ...] + flat = [s for item in logps for s in (item if isinstance(item[0], (list, tuple)) else [item])] + from twinkle.utils import pad_and_stack_tensors + ref_outputs['logps'] = pad_and_stack_tensors( + [torch.tensor(s, dtype=torch.float32) for s in flat], pad_value=0.0, concat=False) output = super().forward_backward(inputs=inputs, **kwargs) return to_cpu_safe_output(output) diff --git a/src/twinkle/utils/nccl_safe.py b/src/twinkle/utils/nccl_safe.py index 5311ee793..265b6a7dd 100644 --- a/src/twinkle/utils/nccl_safe.py +++ b/src/twinkle/utils/nccl_safe.py @@ -86,8 +86,9 @@ def __call__(self, inputs, outputs, **kwargs): try: return self._loss_instance(inputs, outputs, **kwargs) except Exception as e: - logger.warning(f'[nccl_safe] Loss computation skipped due to error: ' - f'{type(e).__name__}: {e}') + import traceback + logger.warning('[nccl_safe] Loss computation skipped due to error: ' + '%s: %s\n%s', type(e).__name__, e, traceback.format_exc()) return _zero_loss(outputs) @@ -308,8 +309,10 @@ def wrapper(self, *args, **kwargs): try: return fn(self, *args, **kwargs) except Exception as e: + import traceback logger.warning(f'[nccl_safe_megatron] Exception in Megatron method ' - f'{fn.__name__}: {type(e).__name__}: {e}') + f'{fn.__name__}: {type(e).__name__}: {e}\n' + f'{traceback.format_exc()}') # Return safe fallback to prevent NCCL hang on other ranks if tinker: diff --git a/tests/server/integration/test_dpo_tinker_pp_e2e.py b/tests/server/integration/test_dpo_tinker_pp_e2e.py new file mode 100644 index 000000000..ca872be56 --- /dev/null +++ b/tests/server/integration/test_dpo_tinker_pp_e2e.py @@ -0,0 +1,149 @@ +"""Tinker client DPO test on Megatron PP=2 backend. + +Reproduces the tinker_forward_only ragged logps error. + +Usage: + # 1. Start server (PP=2, no sampler) + python tests/server/start_e2e_server.py \ + --config tests/server/config/server_config_4b_dpo_megatron.yaml + + # 2. Run test + TWINKLE_TEST_GPU_E2E=1 python -u tests/server/integration/test_dpo_tinker_pp_e2e.py +""" +from __future__ import annotations + +import os +import sys +import time + +import numpy as np +import torch + +SERVER_URL = os.environ.get('TWINKLE_SERVER_URL', 'http://localhost:9000') +BASE_MODEL = 'Qwen/Qwen3.5-4B' + + +def log(msg): + ts = time.strftime('%Y-%m-%d %H:%M:%S') + print(f'[{ts}][INFO:twinkle] {msg}', flush=True) + + +def main(): + if not os.environ.get('TWINKLE_TEST_GPU_E2E'): + log('SKIP: set TWINKLE_TEST_GPU_E2E=1 to run') + return 0 + + # ── Init tinker client ──────────────────────────────────────────── + from tinker import types + from twinkle import init_tinker_client + init_tinker_client() + from tinker import ServiceClient + + service_client = ServiceClient(base_url=SERVER_URL, api_key='EMPTY_TOKEN') + training_client = service_client.create_lora_training_client( + base_model=BASE_MODEL, rank=8, + ) + log(f'Tinker training client ready (model={BASE_MODEL})') + + # ── Prepare DPO dataset ─────────────────────────────────────────── + from twinkle.dataset import Dataset, DatasetMeta + from twinkle.preprocessor import EmojiDPOProcessor + from twinkle.server.common import input_feature_to_datum + + log('Loading DPO dataset (10 samples)...') + dataset = Dataset(DatasetMeta(f'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji', data_slice=range(10))) + dataset.set_template('Qwen3_5Template', model_id=f'ms://{BASE_MODEL}', max_length=1024) + dataset.map(EmojiDPOProcessor, init_args={'system': 'You are a helpful assistant.'}) + dataset.encode() + + # Build interleaved [pos, neg, pos, neg] batch + batch = list(dataset)[:4] + dpo_batch = [] + for row in batch: + for key in list(row.keys()): + if isinstance(row[key], np.ndarray): + row[key] = row[key].tolist() + elif isinstance(row[key], torch.Tensor): + row[key] = row[key].cpu().numpy().tolist() + base_fields = {k: v for k, v in row.items() if k not in ('positive', 'negative')} + dpo_batch.append({**base_fields, **row['positive']}) + dpo_batch.append({**base_fields, **row['negative']}) + log(f'DPO batch: {len(dpo_batch)} samples (interleaved pos/neg)') + + # Convert to Tinker Datums + input_datums = [input_feature_to_datum(row) for row in dpo_batch] + seq_lens = [d.loss_fn_inputs['target_tokens'].to_numpy().shape[0] for d in input_datums] + log(f'Datum seq_lens: {seq_lens}') + + # ── Step 1: Reference forward (tinker_forward_only) ─────────────── + log('=' * 60) + log('Step 1: tinker forward (reference, disable_lora=True)...') + log('=' * 60) + start = time.time() + try: + ref_result = training_client.forward( + input_datums, 'cross_entropy', + loss_fn_config={'disable_lora': True}, + ).result() + elapsed = time.time() - start + log(f'Step 1 OK ({elapsed:.1f}s), {len(ref_result.loss_fn_outputs)} outputs') + + # Show logprobs shapes + for i, out in enumerate(ref_result.loss_fn_outputs): + lp = out.get('logprobs') + if lp is not None: + arr = np.array(lp.tolist()) + log(f' output[{i}] logprobs shape={arr.shape}') + except Exception as e: + elapsed = time.time() - start + log(f'Step 1 FAILED ({elapsed:.1f}s): {type(e).__name__}: {e}') + if elapsed > 120: + log('TIMEOUT — likely NCCL hang!') + log('Check server log for traceback') + return 1 + + # ── Step 2: Attach ref_logps to datums ──────────────────────────── + log('Step 2: Attaching ref_logps to datums...') + for datum, ref_out in zip(input_datums, ref_result.loss_fn_outputs): + ref_logprobs_np = np.array(ref_out['logprobs'].tolist(), dtype=np.float32) + datum.loss_fn_inputs['ref_logps'] = types.TensorData.from_numpy(ref_logprobs_np) + log(f' ref_logps shape={ref_logprobs_np.shape}') + + # ── Step 3: DPO forward_backward (tinker_forward_backward) ──────── + log('=' * 60) + log('Step 3: tinker forward_backward (DPO loss)...') + log('=' * 60) + start = time.time() + try: + fwdbwd_result = training_client.forward_backward( + input_datums, 'importance_sampling', + loss_fn_config={'dpo_beta': 0.1, 'dpo_sft_weight': 1.0}, + ).result() + elapsed = time.time() - start + log(f'Step 3 OK ({elapsed:.1f}s)') + except Exception as e: + elapsed = time.time() - start + log(f'Step 3 FAILED ({elapsed:.1f}s): {type(e).__name__}: {e}') + if elapsed > 120: + log('TIMEOUT — likely NCCL hang!') + return 1 + + # ── Step 4: Optimizer step ──────────────────────────────────────── + log('Step 4: optim_step...') + try: + optim_result = training_client.optim_step( + types.AdamParams(learning_rate=1e-4) + ).result() + log(f'Step 4 OK, metrics={optim_result.metrics}') + except Exception as e: + log(f'Step 4 FAILED: {e}') + return 1 + + log('=' * 60) + log('ALL TINKER DPO PHASES PASSED') + log('=' * 60) + return 0 + + +if __name__ == '__main__': + sys.exit(main()) From 72ffacf8685b2c7c230305c9d0861e64eee0bc1d Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Fri, 26 Jun 2026 09:57:27 +0800 Subject: [PATCH 10/16] fix lint --- cookbook/client/tinker/modelscope/dpo.py | 2 +- src/twinkle/server/gateway/twinkle_handlers.py | 3 +-- src/twinkle/server/model/backends/common.py | 8 +++++--- src/twinkle/server/model/backends/megatron_model.py | 5 +++-- src/twinkle/utils/nccl_safe.py | 3 ++- 5 files changed, 12 insertions(+), 9 deletions(-) diff --git a/cookbook/client/tinker/modelscope/dpo.py b/cookbook/client/tinker/modelscope/dpo.py index 23cf5aaee..cfbe916e9 100644 --- a/cookbook/client/tinker/modelscope/dpo.py +++ b/cookbook/client/tinker/modelscope/dpo.py @@ -51,7 +51,7 @@ max_length = 2048 lora_rank = 8 system_prompt = 'You are a helpful assistant.' -use_swanlab = True +use_swanlab = False # --------------------------------------------------------------------------- diff --git a/src/twinkle/server/gateway/twinkle_handlers.py b/src/twinkle/server/gateway/twinkle_handlers.py index 107c00c53..c3de4d8a5 100644 --- a/src/twinkle/server/gateway/twinkle_handlers.py +++ b/src/twinkle/server/gateway/twinkle_handlers.py @@ -54,8 +54,7 @@ async def healthz_deep( for model in self.supported_models: model_name = model.model_name try: - resp = await self.proxy.proxy_request( - request, 'healthz', model_name, 'model') + resp = await self.proxy.proxy_request(request, 'healthz', model_name, 'model') healthy = (resp.status_code == 200) if not healthy: all_healthy = False diff --git a/src/twinkle/server/model/backends/common.py b/src/twinkle/server/model/backends/common.py index 2e5b4a783..7c90a4ed0 100644 --- a/src/twinkle/server/model/backends/common.py +++ b/src/twinkle/server/model/backends/common.py @@ -210,9 +210,11 @@ def _tensor_output_to_rows(value, seq_lens: list[int], *, kind: str) -> list[tor if isinstance(value, list) and value and isinstance(value[0], (list, tuple)): flat = [s for item in value for s in (item if isinstance(item[0], (list, tuple)) else [item])] from twinkle.utils import pad_and_stack_tensors - tensors = [pad_and_stack_tensors( - [torch.tensor(s, dtype=torch.float32) for s in flat], - pad_value=0.0, concat=False)] + tensors = [ + pad_and_stack_tensors([torch.tensor(s, dtype=torch.float32) for s in flat], + pad_value=0.0, + concat=False) + ] else: tensors = [torch.as_tensor(value, dtype=torch.float32)] diff --git a/src/twinkle/server/model/backends/megatron_model.py b/src/twinkle/server/model/backends/megatron_model.py index 1e73b97ae..fd42a29b3 100644 --- a/src/twinkle/server/model/backends/megatron_model.py +++ b/src/twinkle/server/model/backends/megatron_model.py @@ -121,8 +121,9 @@ def forward_backward(self, *, inputs: InputFeature | list[InputFeature] | Trajec # Flatten [[mb0_sample0, mb0_sample1], [mb1_sample0, ...]] → [sample0, sample1, ...] flat = [s for item in logps for s in (item if isinstance(item[0], (list, tuple)) else [item])] from twinkle.utils import pad_and_stack_tensors - ref_outputs['logps'] = pad_and_stack_tensors( - [torch.tensor(s, dtype=torch.float32) for s in flat], pad_value=0.0, concat=False) + ref_outputs['logps'] = pad_and_stack_tensors([torch.tensor(s, dtype=torch.float32) for s in flat], + pad_value=0.0, + concat=False) output = super().forward_backward(inputs=inputs, **kwargs) return to_cpu_safe_output(output) diff --git a/src/twinkle/utils/nccl_safe.py b/src/twinkle/utils/nccl_safe.py index 265b6a7dd..d012548b6 100644 --- a/src/twinkle/utils/nccl_safe.py +++ b/src/twinkle/utils/nccl_safe.py @@ -88,7 +88,8 @@ def __call__(self, inputs, outputs, **kwargs): except Exception as e: import traceback logger.warning('[nccl_safe] Loss computation skipped due to error: ' - '%s: %s\n%s', type(e).__name__, e, traceback.format_exc()) + '%s: %s\n%s', + type(e).__name__, e, traceback.format_exc()) return _zero_loss(outputs) From 40e0a8d12c2932c5e99c7e5eb94793ceae289fd3 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Fri, 26 Jun 2026 12:18:43 +0800 Subject: [PATCH 11/16] fix review --- cookbook/client/server/megatron/entrypoint.sh | 2 +- src/twinkle/utils/nccl_safe.py | 6 +- .../config/server_config_4b_dpo_megatron.yaml | 90 -- .../config/server_config_4b_e2e_megatron.yaml | 72 +- ...rver_config_4b_e2e_megatron_failfast1.yaml | 123 +++ tests/server/integration/openai_e2e.py | 196 ---- .../server/integration/test_race_nccl_hang.py | 251 ----- tests/server/test_nccl_safe.py | 960 ------------------ 8 files changed, 162 insertions(+), 1538 deletions(-) delete mode 100644 tests/server/config/server_config_4b_dpo_megatron.yaml delete mode 100644 tests/server/integration/openai_e2e.py delete mode 100644 tests/server/integration/test_race_nccl_hang.py delete mode 100644 tests/server/test_nccl_safe.py diff --git a/cookbook/client/server/megatron/entrypoint.sh b/cookbook/client/server/megatron/entrypoint.sh index 6f55d4e68..879e3890c 100755 --- a/cookbook/client/server/megatron/entrypoint.sh +++ b/cookbook/client/server/megatron/entrypoint.sh @@ -82,7 +82,7 @@ check_http_health() { fi if command -v wget &> /dev/null; then - wget -q --spider --timeout=10 "$url" + wget -q -O /dev/null --timeout=10 "$url" return fi diff --git a/src/twinkle/utils/nccl_safe.py b/src/twinkle/utils/nccl_safe.py index d012548b6..8b9740d45 100644 --- a/src/twinkle/utils/nccl_safe.py +++ b/src/twinkle/utils/nccl_safe.py @@ -241,10 +241,10 @@ def _force_zero_backward(model, og, adapter_name, kwargs): # Fallback: use first model parameter to maintain graph connectivity. # Do NOT detach() the parameter -- the zero loss must remain connected # to the model's autograd graph so FSDP ReduceScatter hooks fire. + # Use lazy iteration to avoid materializing the full parameter list. try: - params = [p for p in _iter_model_params(model) if p.requires_grad] - if params: - param = params[0] + param = next((p for p in _iter_model_params(model) if p.requires_grad), None) + if param is not None: zero_loss = (param.flatten()[0] * 0).sum() else: zero_loss = torch.zeros((), device='cuda', requires_grad=True) diff --git a/tests/server/config/server_config_4b_dpo_megatron.yaml b/tests/server/config/server_config_4b_dpo_megatron.yaml deleted file mode 100644 index 1fc274901..000000000 --- a/tests/server/config/server_config_4b_dpo_megatron.yaml +++ /dev/null @@ -1,90 +0,0 @@ -# Twinkle Server Configuration - DPO E2E Test (4B model, Megatron PP=2, no sampler) -# Minimal config for reproducing PP deadlock during DPO forward_backward. - -proxy_location: EveryNode - -http_options: - host: 0.0.0.0 - port: 9000 - -applications: - - - name: server - route_prefix: /api/v1 - import_path: server - args: - server_config: - per_token_model_limit: 30 - supported_models: - - Qwen/Qwen3.5-4B - deployments: - - name: TinkerCompatServer - max_ongoing_requests: 50 - autoscaling_config: - min_replicas: 1 - max_replicas: 1 - target_ongoing_requests: 128 - ray_actor_options: - num_cpus: 0.1 - runtime_env: - env_vars: - TWINKLE_FAIL_FAST: "0" - - - name: models-Qwen3.5-4B - route_prefix: /api/v1/model/Qwen/Qwen3.5-4B - import_path: model - args: - backend: megatron - model_id: "ms://Qwen/Qwen3.5-4B" - max_length: 10240 - nproc_per_node: 4 - device_group: - name: model - ranks: 4 - device_type: cuda - device_mesh: - device_type: cuda - dp_size: 2 - pp_size: 2 - queue_config: - rps_limit: 100 - tps_limit: 100000 - adapter_config: - adapter_timeout: 30 - adapter_max_lifetime: 36000 - deployments: - - name: ModelManagement - autoscaling_config: - min_replicas: 1 - max_replicas: 1 - target_ongoing_requests: 16 - ray_actor_options: - num_cpus: 0.1 - runtime_env: - env_vars: - TWINKLE_TRUST_REMOTE_CODE: "1" - TWINKLE_FAIL_FAST: "0" - - - name: processor - route_prefix: /api/v1/processor - import_path: processor - args: - ncpu_proc_per_node: 2 - device_group: - name: model - ranks: 2 - device_type: CPU - device_mesh: - device_type: CPU - dp_size: 2 - deployments: - - name: ProcessorManagement - autoscaling_config: - min_replicas: 1 - max_replicas: 1 - target_ongoing_requests: 128 - ray_actor_options: - num_cpus: 0.1 - runtime_env: - env_vars: - TWINKLE_FAIL_FAST: "0" diff --git a/tests/server/config/server_config_4b_e2e_megatron.yaml b/tests/server/config/server_config_4b_e2e_megatron.yaml index fc330b58e..a6f193501 100644 --- a/tests/server/config/server_config_4b_e2e_megatron.yaml +++ b/tests/server/config/server_config_4b_e2e_megatron.yaml @@ -1,4 +1,4 @@ -# Twinkle Server Configuration - E2E Test (4B model, Megatron backend) +# Twinkle Server Configuration - E2E Test (4B model, Megatron DP=2 PP=2) proxy_location: EveryNode @@ -50,7 +50,6 @@ applications: tps_limit: 100000 adapter_config: adapter_timeout: 30 - adapter_max_lifetime: 36000 deployments: - name: ModelManagement autoscaling_config: @@ -64,41 +63,40 @@ applications: TWINKLE_TRUST_REMOTE_CODE: "1" TWINKLE_FAIL_FAST: "0" - # - name: sampler-Qwen3.5-4B - # route_prefix: /api/v1/sampler/Qwen/Qwen3.5-4B - # import_path: sampler - # args: - # model_id: "ms://Qwen/Qwen3.5-4B" - # nproc_per_node: 1 - # sampler_type: vllm - # engine_args: - # max_model_len: 4096 - # gpu_memory_utilization: 0.5 - # enable_lora: true - # max_loras: 5 - # logprobs_mode: processed_logprobs - # device_group: - # name: sampler - # ranks: 1 - # device_type: cuda - # device_mesh: - # device_type: cuda - # dp_size: 1 - # queue_config: - # rps_limit: 100 - # tps_limit: 100000 - # deployments: - # - name: SamplerManagement - # autoscaling_config: - # min_replicas: 1 - # max_replicas: 1 - # target_ongoing_requests: 16 - # ray_actor_options: - # num_cpus: 0.1 - # runtime_env: - # env_vars: - # TWINKLE_TRUST_REMOTE_CODE: "1" - # TWINKLE_FAIL_FAST: "0" + - name: sampler-Qwen3.5-4B + route_prefix: /api/v1/sampler/Qwen/Qwen3.5-4B + import_path: sampler + args: + model_id: "ms://Qwen/Qwen3.5-4B" + nproc_per_node: 1 + sampler_type: vllm + engine_args: + max_model_len: 4096 + gpu_memory_utilization: 0.5 + enable_lora: true + logprobs_mode: processed_logprobs + device_group: + name: sampler + ranks: 1 + device_type: cuda + device_mesh: + device_type: cuda + dp_size: 1 + queue_config: + rps_limit: 100 + tps_limit: 100000 + deployments: + - name: SamplerManagement + autoscaling_config: + min_replicas: 1 + max_replicas: 1 + target_ongoing_requests: 16 + ray_actor_options: + num_cpus: 0.1 + runtime_env: + env_vars: + TWINKLE_TRUST_REMOTE_CODE: "1" + TWINKLE_FAIL_FAST: "0" - name: processor route_prefix: /api/v1/processor diff --git a/tests/server/config/server_config_4b_e2e_megatron_failfast1.yaml b/tests/server/config/server_config_4b_e2e_megatron_failfast1.yaml index 0762bcfc4..848f4755d 100644 --- a/tests/server/config/server_config_4b_e2e_megatron_failfast1.yaml +++ b/tests/server/config/server_config_4b_e2e_megatron_failfast1.yaml @@ -1,3 +1,126 @@ +# Twinkle Server Configuration - E2E Test (4B model, Megatron DP=2 PP=2, FAIL_FAST=1) + +proxy_location: EveryNode + +http_options: + host: 0.0.0.0 + port: 9000 + +applications: + + - name: server + route_prefix: /api/v1 + import_path: server + args: + server_config: + per_token_model_limit: 20 + supported_models: + - Qwen/Qwen3.5-4B + deployments: + - name: TinkerCompatServer + max_ongoing_requests: 50 + autoscaling_config: + min_replicas: 1 + max_replicas: 1 + target_ongoing_requests: 128 + ray_actor_options: + num_cpus: 0.1 + runtime_env: + env_vars: + TWINKLE_FAIL_FAST: "1" + + - name: models-Qwen3.5-4B + route_prefix: /api/v1/model/Qwen/Qwen3.5-4B + import_path: model + args: + backend: megatron + model_id: "ms://Qwen/Qwen3.5-4B" + max_length: 10240 + nproc_per_node: 4 + device_group: + name: model + ranks: 4 + device_type: cuda + device_mesh: + device_type: cuda + dp_size: 2 + pp_size: 2 + queue_config: + rps_limit: 100 + tps_limit: 100000 + adapter_config: + adapter_timeout: 30 + deployments: + - name: ModelManagement + autoscaling_config: + min_replicas: 1 + max_replicas: 1 + target_ongoing_requests: 16 + ray_actor_options: + num_cpus: 0.1 + runtime_env: + env_vars: + TWINKLE_TRUST_REMOTE_CODE: "1" + TWINKLE_FAIL_FAST: "1" + + - name: sampler-Qwen3.5-4B + route_prefix: /api/v1/sampler/Qwen/Qwen3.5-4B + import_path: sampler + args: + model_id: "ms://Qwen/Qwen3.5-4B" + nproc_per_node: 1 + sampler_type: vllm + engine_args: + max_model_len: 4096 + gpu_memory_utilization: 0.5 + enable_lora: true + logprobs_mode: processed_logprobs + device_group: + name: sampler + ranks: 1 + device_type: cuda + device_mesh: + device_type: cuda + dp_size: 1 + queue_config: + rps_limit: 100 + tps_limit: 100000 + deployments: + - name: SamplerManagement + autoscaling_config: + min_replicas: 1 + max_replicas: 1 + target_ongoing_requests: 16 + ray_actor_options: + num_cpus: 0.1 + runtime_env: + env_vars: + TWINKLE_TRUST_REMOTE_CODE: "1" + TWINKLE_FAIL_FAST: "1" + + - name: processor + route_prefix: /api/v1/processor + import_path: processor + args: + ncpu_proc_per_node: 2 + device_group: + name: model + ranks: 2 + device_type: CPU + device_mesh: + device_type: CPU + dp_size: 2 + deployments: + - name: ProcessorManagement + autoscaling_config: + min_replicas: 1 + max_replicas: 1 + target_ongoing_requests: 128 + ray_actor_options: + num_cpus: 0.1 + runtime_env: + env_vars: + TWINKLE_FAIL_FAST: "1" # Twinkle Server Configuration - E2E Test (4B model, Megatron backend) proxy_location: EveryNode diff --git a/tests/server/integration/openai_e2e.py b/tests/server/integration/openai_e2e.py deleted file mode 100644 index 546798ca2..000000000 --- a/tests/server/integration/openai_e2e.py +++ /dev/null @@ -1,196 +0,0 @@ -# OpenAI-compatible endpoint E2E test -# -# Requires: server with vLLM sampler running (server_e2e.py config). -# Tests both non-streaming and streaming /v1/chat/completions via OpenAI SDK. -# -# Usage: -# python -u tests/server/integration/openai_e2e.py - -import sys -import time -from openai import OpenAI - -BASE_URL = 'http://127.0.0.1:8000/api/v1' -API_KEY = 'EMPTY_API_KEY' -MODEL = 'Qwen/Qwen3.5-4B' - - -def test_list_models(client: OpenAI): - print('--- Step 1: GET /models ---') - models = client.models.list() - model_ids = [m.id for m in models.data] - print(f' Available models: {model_ids}') - assert MODEL in model_ids, f'{MODEL} not in {model_ids}' - print(' PASS\n') - - -def test_non_streaming(client: OpenAI): - print('--- Step 2: Non-streaming chat completion ---') - t0 = time.time() - resp = client.chat.completions.create( - model=MODEL, - messages=[ - { - 'role': 'system', - 'content': 'You are a helpful assistant.' - }, - { - 'role': 'user', - 'content': 'What is 2+2? Answer in one word.' - }, - ], - max_tokens=32, - temperature=0.1, - ) - elapsed = time.time() - t0 - print(f' Model: {resp.model}') - print(f' Choices: {len(resp.choices)}') - content = resp.choices[0].message.content - print(f' Content: {content!r}') - print(f' Finish reason: {resp.choices[0].finish_reason}') - print(f' Usage: prompt={resp.usage.prompt_tokens}, completion={resp.usage.completion_tokens}') - print(f' Elapsed: {elapsed:.2f}s') - assert content and len(content) > 0, 'Empty response' - assert resp.choices[0].finish_reason in ('stop', 'length') - print(' PASS\n') - - -def test_streaming(client: OpenAI): - print('--- Step 3: Streaming chat completion ---') - t0 = time.time() - stream = client.chat.completions.create( - model=MODEL, - messages=[ - { - 'role': 'user', - 'content': 'Count from 1 to 5.' - }, - ], - max_tokens=64, - temperature=0.1, - stream=True, - ) - - chunks = [] - full_content = '' - for chunk in stream: - chunks.append(chunk) - delta = chunk.choices[0].delta - if delta.content: - full_content += delta.content - print(f' chunk: {delta.content[:60]!r}...' if len(delta.content or '') > - 60 else f' chunk: {delta.content!r}') - - elapsed = time.time() - t0 - print(f' Total chunks: {len(chunks)}') - print(f' Full content length: {len(full_content)} chars') - print(f' Elapsed: {elapsed:.2f}s') - assert len(chunks) >= 1, 'Expected at least one chunk' - assert full_content and len(full_content) > 0, 'Empty streamed response' - # Find the last chunk with a finish_reason - last_finish = None - for c in reversed(chunks): - if c.choices[0].finish_reason: - last_finish = c.choices[0].finish_reason - break - print(f' Finish reason: {last_finish}') - assert last_finish in ('stop', 'length') - print(' PASS\n') - - -def test_multi_turn(client: OpenAI): - print('--- Step 4: Multi-turn conversation ---') - resp = client.chat.completions.create( - model=MODEL, - messages=[ - { - 'role': 'system', - 'content': 'You are a math tutor.' - }, - { - 'role': 'user', - 'content': 'What is 3*7?' - }, - { - 'role': 'assistant', - 'content': '3*7 = 21' - }, - { - 'role': 'user', - 'content': 'Now add 4 to that.' - }, - ], - max_tokens=32, - temperature=0.1, - ) - content = resp.choices[0].message.content - print(f' Content: {content!r}') - assert content and len(content) > 0, 'Empty response' - print(' PASS\n') - - -def test_sticky_session(base_url: str): - """Verify sticky session routing with X-Twinkle-Replica-Id header. - - Sends multiple requests with the same model key and asserts they all - report the same replica ID. This proves Serve-Multiplexed-Model-Id - correctly pins requests to a single replica. - - Requires the sampler to have the inject_replica_id middleware - (added in sampler/app.py). - """ - import httpx - - print('--- Step 5: Sticky session verification (replica-id) ---') - - replica_ids = [] - for i in range(5): - resp = httpx.post( - f'{base_url}/chat/completions', - json={ - 'model': MODEL, - 'messages': [{ - 'role': 'user', - 'content': f'Say {i}.' - }], - 'max_tokens': 5, - 'temperature': 0.1, - }, - headers={'Authorization': 'Bearer EMPTY_API_KEY'}, - timeout=60, - ) - assert resp.status_code == 200, f'Request {i} failed: {resp.status_code}' - rid = resp.headers.get('x-twinkle-replica-id') - replica_ids.append(rid) - print(f' Request {i}: replica_id={rid}') - - # All requests with same model must hit the same replica - unique = set(replica_ids) - assert None not in unique, 'X-Twinkle-Replica-Id header missing from responses' - assert len(unique) == 1, (f'Sticky session broken: requests routed to {len(unique)} different replicas: {unique}') - print(f' All 5 requests → replica {replica_ids[0]}') - print(' PASS\n') - - -def main(): - client = OpenAI(base_url=BASE_URL, api_key=API_KEY) - - test_list_models(client) - test_non_streaming(client) - test_streaming(client) - test_multi_turn(client) - test_sticky_session(BASE_URL) - - print('=' * 50) - print('ALL STEPS PASSED') - print('=' * 50) - - -if __name__ == '__main__': - try: - main() - except Exception as e: - print(f'\nFAILED: {e}', file=sys.stderr) - import traceback - traceback.print_exc() - sys.exit(1) diff --git a/tests/server/integration/test_race_nccl_hang.py b/tests/server/integration/test_race_nccl_hang.py deleted file mode 100644 index 3072a7aae..000000000 --- a/tests/server/integration/test_race_nccl_hang.py +++ /dev/null @@ -1,251 +0,0 @@ -"""Multi-client race condition stress test: reproduce ncclCommSplit hang. - -Strategy: 3 concurrent clients hammer the Megatron server in a loop: -- Client A: continuous DPO training (forward_only + forward_backward + step) -- Client B: continuous DPO training with occasional bad data (triggers errors) -- Client C: repeatedly creates/destroys adapters (triggers ncclCommSplit) - -Runs multiple rounds. If any operation takes >TIMEOUT seconds, NCCL hang is detected. - -Usage: - python tests/server/integration/test_race_nccl_hang.py - python tests/server/integration/test_race_nccl_hang.py --rounds 5 -""" -import argparse -import os -import sys -import threading -import time - -os.environ['TINKER_BASE_URL'] = 'http://localhost:9000' -os.environ['TWINKLE_SERVER_TOKEN'] = 'EMPTY_TOKEN' - -from twinkle_client import init_twinkle_client -from twinkle_client.model import MultiLoraTransformersModel -from peft import LoraConfig - -SERVER_URL = os.environ.get('TWINKLE_SERVER_URL', 'http://localhost:9000') -TIMEOUT = 90 -seq_len = 64 - - -def make_batch(size=4, include_position_ids=True): - batch = [] - for _ in range(size): - item = { - 'input_ids': list(range(1, seq_len + 1)), - 'labels': [-100] * 32 + list(range(100, 132)), - 'attention_mask': [1] * seq_len, - } - if include_position_ids: - item['position_ids'] = list(range(seq_len)) - batch.append(item) - return batch - - -results = {'hangs': [], 'errors': [], 'steps_ok': 0, 'rounds_ok': 0} -lock = threading.Lock() -stop_event = threading.Event() - - -def log(msg): - print(f'[RACE] {msg}', flush=True) - - -def create_session(name): - init_twinkle_client(base_url=SERVER_URL, api_key='EMPTY_TOKEN') - model = MultiLoraTransformersModel(model_id='ms://Qwen/Qwen3.5-4B') - model.add_adapter_to_model( - adapter_name=name, - config=LoraConfig(r=16, target_modules=['q_proj', 'v_proj']), - gradient_accumulation_steps=1, - ) - model.set_loss('DPOLoss', init_args={'beta': 0.1}) - model.set_optimizer('Adam', lr=1e-5) - model.set_template('Qwen3_5Template') - model.set_processor('InputProcessor', padding_side='right') - return model - - -def client_a_training(steps_per_round): - """Client A: continuous normal DPO training.""" - try: - model = create_session('client-a') - log('Client-A: session ready') - batch = make_batch(4) - for i in range(steps_per_round): - if stop_event.is_set(): - return - start = time.time() - model.forward_only(inputs=batch, disable_lora=True) - model.forward_backward(inputs=batch) - model.clip_grad_and_step() - elapsed = time.time() - start - with lock: - results['steps_ok'] += 1 - if i % 3 == 0: - log(f'Client-A: step {i + 1}/{steps_per_round} ({elapsed:.1f}s)') - if elapsed > TIMEOUT: - with lock: - results['hangs'].append(f'Client-A step {i + 1} ({elapsed:.0f}s)') - stop_event.set() - return - except Exception as e: - with lock: - results['errors'].append(f'Client-A: {type(e).__name__}: {str(e)[:80]}') - log(f'Client-A: ERROR {type(e).__name__}: {str(e)[:80]}') - - -def client_b_mixed_training(steps_per_round): - """Client B: training with mix of good and bad requests.""" - try: - model = create_session('client-b') - log('Client-B: session ready') - good_batch = make_batch(4) - bad_batch_no_pos = make_batch(4, include_position_ids=False) # missing position_ids - bad_batch_odd = make_batch(3) # odd size for DPO - - for i in range(steps_per_round): - if stop_event.is_set(): - return - start = time.time() - try: - # Every 4th request: send bad data to trigger error - if i % 4 == 3: - model.forward_backward(inputs=bad_batch_no_pos) - elif i % 7 == 6: - model.forward_backward(inputs=bad_batch_odd) - else: - model.forward_only(inputs=good_batch, disable_lora=True) - model.forward_backward(inputs=good_batch) - model.clip_grad_and_step() - elapsed = time.time() - start - with lock: - results['steps_ok'] += 1 - if i % 3 == 0: - log(f'Client-B: step {i + 1}/{steps_per_round} ({elapsed:.1f}s)') - except Exception: - elapsed = time.time() - start - if elapsed > TIMEOUT: - with lock: - results['hangs'].append(f'Client-B step {i + 1} ({elapsed:.0f}s)') - stop_event.set() - return - # Expected errors from bad data - continue - if i % 5 == 0: - log(f'Client-B: step {i + 1} error (expected, {elapsed:.1f}s)') - except Exception as e: - with lock: - results['errors'].append(f'Client-B: {type(e).__name__}: {str(e)[:80]}') - log(f'Client-B: ERROR {type(e).__name__}: {str(e)[:80]}') - - -def client_c_adapter_churn(count): - """Client C: repeatedly create adapters (triggers ncclCommSplit).""" - time.sleep(0.5) # let training start first - for i in range(count): - if stop_event.is_set(): - return - name = f'churn-{i}' - start = time.time() - try: - init_twinkle_client(base_url=SERVER_URL, api_key='EMPTY_TOKEN') - m = MultiLoraTransformersModel(model_id='ms://Qwen/Qwen3.5-4B') - m.add_adapter_to_model( - adapter_name=name, - config=LoraConfig(r=16, target_modules=['q_proj', 'v_proj']), - gradient_accumulation_steps=1, - ) - elapsed = time.time() - start - if i % 2 == 0: - log(f'Client-C: adapter {name} OK ({elapsed:.1f}s)') - if elapsed > TIMEOUT: - with lock: - results['hangs'].append(f'Client-C {name} ({elapsed:.0f}s)') - stop_event.set() - return - except Exception as e: - elapsed = time.time() - start - if elapsed > TIMEOUT: - with lock: - results['hangs'].append(f'Client-C {name} ({elapsed:.0f}s)') - stop_event.set() - return - log(f'Client-C: {name} error ({elapsed:.1f}s) - continuing') - time.sleep(0.1) - - -def run_round(round_num, steps_per_round=10, adapter_churn=5): - """Run one round of the stress test.""" - global results - stop_event.clear() - - log(f'--- Round {round_num} (steps={steps_per_round}, churn={adapter_churn}) ---') - - t1 = threading.Thread(target=client_a_training, args=(steps_per_round,)) - t2 = threading.Thread(target=client_b_mixed_training, args=(steps_per_round,)) - t3 = threading.Thread(target=client_c_adapter_churn, args=(adapter_churn,)) - - threads = [t1, t2, t3] - for t in threads: - t.start() - - for t in threads: - t.join(timeout=300) - - alive = [t for t in threads if t.is_alive()] - if alive: - with lock: - results['hangs'].append(f'Round {round_num}: {len(alive)} threads stuck') - return False - - if results['hangs']: - return False - - with lock: - results['rounds_ok'] += 1 - return True - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument('--rounds', type=int, default=3) - parser.add_argument('--steps', type=int, default=10) - parser.add_argument('--churn', type=int, default=5) - args = parser.parse_args() - - log('=' * 60) - log('Multi-client Race Condition STRESS Test') - log(f'TWINKLE_FAIL_FAST = {os.getenv("TWINKLE_FAIL_FAST", "not set")}') - log(f'Rounds={args.rounds}, Steps/round={args.steps}, Adapter churn={args.churn}') - log('=' * 60) - - t_start = time.time() - for r in range(1, args.rounds + 1): - ok = run_round(r, steps_per_round=args.steps, adapter_churn=args.churn) - if not ok: - break - - total = time.time() - t_start - log('') - log('=' * 60) - log(f'FINAL RESULTS ({total:.1f}s total)') - log('=' * 60) - log(f' Rounds OK: {results["rounds_ok"]}/{args.rounds}') - log(f' Steps OK: {results["steps_ok"]}') - log(f' Errors: {len(results["errors"])}') - for e in results['errors'][:5]: - log(f' - {e}') - log(f' Hangs: {results["hangs"]}') - - if results['hangs']: - log('') - log('*** NCCL HANG DETECTED ***') - return 1 - log('') - log('ALL ROUNDS PASSED - no hang detected.') - return 0 - - -if __name__ == '__main__': - sys.exit(main()) diff --git a/tests/server/test_nccl_safe.py b/tests/server/test_nccl_safe.py deleted file mode 100644 index 1bd852a1e..000000000 --- a/tests/server/test_nccl_safe.py +++ /dev/null @@ -1,960 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -"""Tests for NCCL-safe fault tolerance utilities. - -Organized in five tiers: -1. Unit: _is_fail_fast(), safe_loss(), _zero_loss() -2. Unit: @nccl_safe decorator -3. Unit: BaseOptimizerGroup.__setattr__ auto-wrapping -4. Integration: real loss functions (GRPO, CrossEntropy) through safe_loss -5. Adversarial: malicious data injection, monkey-patching, stability -""" -import pytest -import torch - -from unittest.mock import MagicMock - -from twinkle.data_format import LossOutput -from twinkle.loss.base import Loss -from twinkle.loss.cross_entropy import CrossEntropyLoss -from twinkle.loss.grpo import GRPOLoss -from twinkle.model.optimizer_group import BaseOptimizerGroup, TrainStatus -from twinkle.utils.nccl_safe import ( - _is_fail_fast, - _zero_loss, - nccl_safe, - nccl_safe_megatron, - safe_loss, -) - - -# ─── Helpers ────────────────────────────────────────────────────────────── - - -class DummyLoss(Loss): - """Simple loss returning sum of logps.""" - require_logps = True - require_entropy = False - require_logits = False - - def __call__(self, inputs, outputs, **kwargs): - logps = outputs['logps'] - return LossOutput(loss=logps.sum(), num_tokens=logps.numel()) - - -class ExplodingLoss(Loss): - """Loss that always raises RuntimeError.""" - require_logps = True - require_entropy = True - require_logits = False - - def __call__(self, inputs, outputs, **kwargs): - raise RuntimeError('Simulated loss explosion') - - -# ─── Fixtures ───────────────────────────────────────────────────────────── - - -@pytest.fixture(autouse=True) -def _production_mode(monkeypatch): - """Set TWINKLE_FAIL_FAST=0 for all tests (production mode).""" - monkeypatch.setenv('TWINKLE_FAIL_FAST', '0') - - -@pytest.fixture -def _dev_mode(monkeypatch): - """Switch to development mode (TWINKLE_FAIL_FAST=1).""" - monkeypatch.setenv('TWINKLE_FAIL_FAST', '1') - - -# ═══════════════════════════════════════════════════════════════════════════ -# 1. Unit Tests: _is_fail_fast -# ═══════════════════════════════════════════════════════════════════════════ - - -class TestIsFailFast: - - def test_default_is_fail_fast(self, monkeypatch): - monkeypatch.delenv('TWINKLE_FAIL_FAST', raising=False) - assert _is_fail_fast() is True - - def test_explicit_1(self, monkeypatch): - monkeypatch.setenv('TWINKLE_FAIL_FAST', '1') - assert _is_fail_fast() is True - - def test_explicit_0(self, monkeypatch): - monkeypatch.setenv('TWINKLE_FAIL_FAST', '0') - assert _is_fail_fast() is False - - @pytest.mark.parametrize('val', ['no', 'false', 'off', 'NO', 'False', 'OFF']) - def test_falsy_strings(self, monkeypatch, val): - monkeypatch.setenv('TWINKLE_FAIL_FAST', val) - assert _is_fail_fast() is False - - -# ═══════════════════════════════════════════════════════════════════════════ -# 2. Unit Tests: safe_loss -# ═══════════════════════════════════════════════════════════════════════════ - - -class TestSafeLoss: - - def test_transparent_in_dev_mode(self, _dev_mode): - """In dev mode, safe_loss still wraps but wrapper propagates exceptions.""" - loss = ExplodingLoss() - wrapped = safe_loss(loss) - assert wrapped is not loss # always wrapped now - assert wrapped._nccl_safe_wrapped is True - with pytest.raises(RuntimeError, match='Simulated'): - wrapped({}, {'logps': torch.tensor([1.0])}) - - def test_wraps_in_production(self): - loss = DummyLoss() - wrapped = safe_loss(loss) - assert wrapped is not loss - assert callable(wrapped) - - def test_idempotent(self): - loss = DummyLoss() - w1 = safe_loss(loss) - w2 = safe_loss(w1) - assert w1 is w2 - - def test_forwards_attributes(self): - loss = DummyLoss() - w = safe_loss(loss) - assert w.require_logps is True - assert w.require_entropy is False - assert w.require_logits is False - assert w._nccl_safe_wrapped is True - - def test_wrapper_is_loss_subclass(self): - """Wrapped instance must satisfy isinstance(..., Loss) assertions.""" - loss = DummyLoss() - w = safe_loss(loss) - assert isinstance(w, Loss) - - def test_preserves_custom_entropy_flag(self): - loss = GRPOLoss(entropy_coef=0.1) - assert loss.require_entropy is True - w = safe_loss(loss) - assert w.require_entropy is True - - def test_normal_call_passes_through(self): - w = safe_loss(DummyLoss()) - result = w({}, {'logps': torch.tensor([1.0, 2.0, 3.0], requires_grad=True)}) - assert result['loss'].item() == 6.0 - - def test_exception_returns_zero_loss(self): - w = safe_loss(ExplodingLoss()) - result = w({}, {'logps': torch.tensor([1.0], requires_grad=True)}) - assert result['loss'].item() == 0.0 - assert result['num_tokens'] == 0 - - def test_zero_loss_is_graph_connected(self): - w = safe_loss(ExplodingLoss()) - logps = torch.tensor([1.0, 2.0], requires_grad=True) - result = w({}, {'logps': logps}) - result['loss'].backward() - assert logps.grad is not None - - def test_backward_on_zero_loss_yields_zero_grad(self): - w = safe_loss(ExplodingLoss()) - logps = torch.randn(3, requires_grad=True) - result = w({}, {'logps': logps}) - result['loss'].backward() - assert (logps.grad == 0).all() - - -# ═══════════════════════════════════════════════════════════════════════════ -# 3. Unit Tests: _zero_loss -# ═══════════════════════════════════════════════════════════════════════════ - - -class TestZeroLoss: - - def test_from_logps(self): - logps = torch.tensor([1.0, 2.0], requires_grad=True) - r = _zero_loss({'logps': logps}) - assert r['loss'].item() == 0.0 - r['loss'].backward() - assert logps.grad is not None - - def test_from_logits(self): - logits = torch.randn(2, 3, requires_grad=True) - r = _zero_loss({'logits': logits}) - assert r['loss'].item() == 0.0 - r['loss'].backward() - assert logits.grad is not None - - def test_from_loss_key(self): - t = torch.tensor(1.0, requires_grad=True) - r = _zero_loss({'loss': t}) - assert r['loss'].item() == 0.0 - r['loss'].backward() - assert t.grad is not None - - def test_fallback_no_grad_tensor(self): - r = _zero_loss({'logps': torch.tensor([1.0])}) # no requires_grad - assert r['loss'].item() == 0.0 - assert r['loss'].requires_grad - - def test_empty_dict(self): - r = _zero_loss({}) - assert r['loss'].item() == 0.0 - - def test_non_dict(self): - r = _zero_loss('not_a_dict') - assert r['loss'].item() == 0.0 - - def test_num_tokens_zero(self): - r = _zero_loss({'logps': torch.tensor([1.0], requires_grad=True)}) - assert r['num_tokens'] == 0 - - def test_priority_order_logps_first(self): - """logps should be preferred over logits for graph connectivity.""" - logps = torch.tensor([1.0], requires_grad=True) - logits = torch.randn(1, 3, requires_grad=True) - r = _zero_loss({'logps': logps, 'logits': logits}) - r['loss'].backward() - assert logps.grad is not None - assert logits.grad is None - - -# ═══════════════════════════════════════════════════════════════════════════ -# 4. Unit Tests: @nccl_safe decorator -# ═══════════════════════════════════════════════════════════════════════════ - - -def _make_model(adapter_name='default', outputs=None, loss_value='sentinel'): - """Create a mock model with optimizer_group for decorator tests.""" - model = MagicMock() - ts = TrainStatus() - ts.outputs = outputs - ts.loss_value = loss_value - - og = MagicMock() - og.train_status = ts - model.optimizer_group = {adapter_name: og} - model._get_default_group = MagicMock(return_value=adapter_name) - return model, og - - -class TestNcclSafeDecorator: - - # ── Basic behaviour ── - - def test_transparent_in_dev_mode(self, _dev_mode): - @nccl_safe - def method(self, *, inputs, **kwargs): - return {'loss': 1.0} - - model, _ = _make_model() - assert method(model, inputs=[], adapter_name='default') == {'loss': 1.0} - - def test_normal_call_passes(self): - @nccl_safe - def method(self, *, inputs, **kwargs): - return {'loss': 1.0} - - model, _ = _make_model() - assert method(model, inputs=[], adapter_name='default') == {'loss': 1.0} - - # ── Pre-forward failure -> re-raise ── - - def test_pre_forward_error_propagates(self): - @nccl_safe - def method(self, *, inputs, **kwargs): - raise ValueError('pre-forward') - - model, _ = _make_model(outputs=None) - with pytest.raises(ValueError, match='pre-forward'): - method(model, inputs=[], adapter_name='default') - - # ── Post-forward, pre-backward -> force backward ── - - def test_post_forward_pre_backward_forces_backward(self): - outputs_after = {'logps': torch.tensor([1.0], requires_grad=True)} - - @nccl_safe - def method(self, *, inputs, **kwargs): - og = self.optimizer_group['default'] - og.train_status.outputs = outputs_after - og.train_status.loss_value = torch.tensor(1.0) - raise RuntimeError('mid-pipeline') - - model, _ = _make_model(outputs=None, loss_value=None) - model.backward = MagicMock() - model.model = MagicMock() - model.model.parameters = MagicMock( - return_value=iter([torch.randn(3, requires_grad=True)])) - - result = method(model, inputs=[], adapter_name='default') - model.backward.assert_called_once() - assert result['loss'] == 0.0 - - def test_post_forward_model_list_forces_backward(self): - """Fallback path must tolerate ``model.model`` being a list (Megatron multi-LoRA).""" - outputs_after = {'logps': torch.tensor([1.0], requires_grad=True)} - - @nccl_safe - def method(self, *, inputs, **kwargs): - og = self.optimizer_group['default'] - og.train_status.outputs = outputs_after - og.train_status.loss_value = torch.tensor(1.0) - raise RuntimeError('mid-pipeline') - - model, _ = _make_model(outputs=None, loss_value=None) - model.backward = MagicMock() - param1 = torch.randn(3, requires_grad=True) - param2 = torch.randn(3, requires_grad=True) - model.model = [MagicMock(), MagicMock()] - model.model[0].parameters = MagicMock(return_value=iter([param1])) - model.model[1].parameters = MagicMock(return_value=iter([param2])) - - result = method(model, inputs=[], adapter_name='default') - model.backward.assert_called_once() - assert result['loss'] == 0.0 - - # ── Post-backward failure -> no extra backward ── - - def test_post_backward_no_extra_backward(self): - @nccl_safe - def method(self, *, inputs, **kwargs): - og = self.optimizer_group['default'] - og.train_status.outputs = {'logps': torch.tensor([1.0])} - og.train_status.loss_value = None # backward done - raise RuntimeError('post-backward') - - model, _ = _make_model(outputs=None) - model.backward = MagicMock() - - result = method(model, inputs=[], adapter_name='default') - model.backward.assert_not_called() - assert result['loss'] == 0.0 - - # ── tinker mode ── - - def test_tinker_returns_list(self): - @nccl_safe(tinker=True) - def method(self, *, inputs, **kwargs): - og = self.optimizer_group['default'] - og.train_status.outputs = {'logps': torch.tensor([1.0])} - og.train_status.loss_value = None - raise RuntimeError('err') - - model, _ = _make_model(outputs=None) - assert method(model, inputs=[], adapter_name='default') == [[], 0.0] - - # ── No optimizer group -> passthrough ── - - def test_no_optimizer_group_raises(self): - @nccl_safe - def method(self, *, inputs, **kwargs): - raise ValueError('err') - - model = MagicMock() - model.optimizer_group = {} - model._get_default_group = MagicMock(return_value='default') - - with pytest.raises(ValueError, match='err'): - method(model, inputs=[], adapter_name='default') - - # ── gradient_accumulation_steps forwarded ── - - def test_gas_forwarded_to_backward(self): - @nccl_safe - def method(self, *, inputs, **kwargs): - og = self.optimizer_group['default'] - og.train_status.outputs = {'logps': torch.tensor([1.0], requires_grad=True)} - og.train_status.loss_value = torch.tensor(1.0) - raise RuntimeError('err') - - model, _ = _make_model(outputs=None, loss_value=None) - model.backward = MagicMock() - model.model = MagicMock() - model.model.parameters = MagicMock( - return_value=iter([torch.randn(3, requires_grad=True)])) - - method(model, inputs=[], adapter_name='default', gradient_accumulation_steps=4) - _, call_kwargs = model.backward.call_args - assert call_kwargs.get('gradient_accumulation_steps') == 4 - - -# ═══════════════════════════════════════════════════════════════════════════ -# 5. Unit Tests: BaseOptimizerGroup.__setattr__ -# ═══════════════════════════════════════════════════════════════════════════ - - -class TestOptimizerGroupSetattr: - - def test_auto_wraps_loss(self): - og = BaseOptimizerGroup() - loss = DummyLoss() - og.loss_instance = loss - assert og.loss_instance is not loss - assert og.loss_instance._nccl_safe_wrapped is True - - def test_does_not_wrap_none(self): - og = BaseOptimizerGroup() - og.loss_instance = None - assert og.loss_instance is None - - def test_other_attrs_unaffected(self): - og = BaseOptimizerGroup() - og.adapter_name = 'test' - assert og.adapter_name == 'test' - og.cur_step = 42 - assert og.cur_step == 42 - - def test_idempotent_via_setattr(self): - og = BaseOptimizerGroup() - og.loss_instance = DummyLoss() - first = og.loss_instance - og.loss_instance = first - assert og.loss_instance is first - - def test_transparent_in_dev_mode(self, _dev_mode): - """In dev mode, OG auto-wraps but wrapper propagates exceptions.""" - og = BaseOptimizerGroup() - loss = ExplodingLoss() - og.loss_instance = loss - assert og.loss_instance is not loss # always wrapped now - assert og.loss_instance._nccl_safe_wrapped is True - with pytest.raises(RuntimeError, match='Simulated'): - og.loss_instance({}, {'logps': torch.tensor([1.0])}) - - -# ═══════════════════════════════════════════════════════════════════════════ -# 6. Integration: real loss functions through safe_loss -# ═══════════════════════════════════════════════════════════════════════════ - - -def _grpo_fixtures(batch=2, seq_len=8): - """Create valid GRPO inputs/outputs/kwargs.""" - labels = torch.randint(0, 100, (batch, seq_len)) - labels[:, :3] = -100 - inputs = {'labels': labels} - logps = torch.randn(batch, seq_len, requires_grad=True) - outputs = {'logps': logps} - n_valid = (labels != -100).sum(dim=1).tolist() - old_logps = [torch.randn(n).tolist() for n in n_valid] - advantages = torch.randn(batch).tolist() - return inputs, outputs, old_logps, advantages - - -class TestIntegrationRealLosses: - - # ── GRPO ── - - def test_grpo_normal(self): - w = safe_loss(GRPOLoss(epsilon=0.2)) - inp, out, olp, adv = _grpo_fixtures() - r = w(inp, out, old_logps=olp, advantages=adv) - assert r['loss'].requires_grad - - def test_grpo_bad_old_logps_caught(self): - """old_logps length mismatch -> AssertionError -> caught.""" - w = safe_loss(GRPOLoss(epsilon=0.2)) - inp, out, _, adv = _grpo_fixtures() - r = w(inp, out, old_logps=[[0.1, 0.2]], advantages=adv) - assert r['loss'].item() == 0.0 - - def test_grpo_bad_old_logps_graph_connected(self): - """Zero loss from GRPO error should still be graph-connected.""" - w = safe_loss(GRPOLoss(epsilon=0.2)) - labels = torch.tensor([[0, 1, 2, 3, 4]]) - labels[0, :2] = -100 - logps = torch.randn(1, 5, requires_grad=True) - r = w({'labels': labels}, {'logps': logps}, - old_logps=[[0.1, 0.2]], advantages=[1.0]) - r['loss'].backward() - assert logps.grad is not None - - # ── CrossEntropy ── - - def test_cross_entropy_normal(self): - w = safe_loss(CrossEntropyLoss()) - labels = torch.randint(0, 100, (2, 8)) - labels[:, :3] = -100 - logps = torch.randn(2, 8, requires_grad=True) - r = w({'labels': labels}, {'logps': logps}) - assert r['loss'].requires_grad - - def test_cross_entropy_missing_labels_caught(self): - w = safe_loss(CrossEntropyLoss()) - r = w({}, {'logps': torch.randn(2, 8, requires_grad=True)}) - assert r['loss'].item() == 0.0 - - -# ═══════════════════════════════════════════════════════════════════════════ -# 7. Integration: full chain via OptimizerGroup -# ═══════════════════════════════════════════════════════════════════════════ - - -class TestEndToEndOptimizerGroup: - - def test_grpo_via_og(self): - og = BaseOptimizerGroup() - og.loss_instance = GRPOLoss(epsilon=0.2) - inp, out, olp, adv = _grpo_fixtures() - r = og.loss_instance(inp, out, old_logps=olp, advantages=adv) - assert 'loss' in r and r['loss'].requires_grad - - def test_grpo_bad_data_via_og(self): - og = BaseOptimizerGroup() - og.loss_instance = GRPOLoss(epsilon=0.2) - labels = torch.tensor([[0, 1, 2, 3, 4]]) - labels[0, :2] = -100 # 3 valid positions - logps = torch.randn(1, 5, requires_grad=True) - # 2 values for 3 valid positions -> AssertionError in _pad_and_align_to_batch - r = og.loss_instance({'labels': labels}, {'logps': logps}, - old_logps=[[0.1, 0.2]], advantages=[1.0]) - assert r['loss'].item() == 0.0 - - def test_replace_loss_auto_wraps(self): - og = BaseOptimizerGroup() - og.loss_instance = DummyLoss() - first = og.loss_instance - - og.loss_instance = ExplodingLoss() - second = og.loss_instance - - assert first is not second - assert second._nccl_safe_wrapped is True - r = second({}, {'logps': torch.tensor([1.0], requires_grad=True)}) - assert r['loss'].item() == 0.0 - - def test_ce_via_og(self): - og = BaseOptimizerGroup() - og.loss_instance = CrossEntropyLoss() - labels = torch.randint(0, 100, (2, 8)) - labels[:, :3] = -100 - r = og.loss_instance({'labels': labels}, - {'logps': torch.randn(2, 8, requires_grad=True)}) - assert 'loss' in r - - -# ═══════════════════════════════════════════════════════════════════════════ -# 8. Adversarial: monkey-patch injection -# ═══════════════════════════════════════════════════════════════════════════ - - -class TestAdversarial: - - def test_loss_raises_runtime_error(self): - """Loss that raises RuntimeError is caught.""" - class OOMLoss(Loss): - def __call__(self, inputs, outputs, **kwargs): - raise RuntimeError('GPU OOM') - - w = safe_loss(OOMLoss()) - r = w({}, {'logps': torch.tensor([1.0], requires_grad=True)}) - assert r['loss'].item() == 0.0 - - def test_original_grpo_assertion_error(self): - """The original bug: AssertionError in _pad_and_align_to_batch.""" - w = safe_loss(GRPOLoss(epsilon=0.2)) - labels = torch.tensor([[0, 1, 2, 3, 4]]) - labels[0, :2] = -100 - logps = torch.randn(1, 5, requires_grad=True) - # 2 values but 3 valid positions -> AssertionError - r = w({'labels': labels}, {'logps': logps}, - old_logps=[[0.1, 0.2]], advantages=[1.0]) - assert r['loss'].item() == 0.0 - r['loss'].backward() - assert logps.grad is not None - - def test_nan_passes_through(self): - """NaN loss is NOT an exception -- should NOT be caught.""" - class NanLoss(Loss): - def __call__(self, inputs, outputs, **kw): - return LossOutput( - loss=torch.tensor(float('nan'), requires_grad=True), - num_tokens=1) - - w = safe_loss(NanLoss()) - r = w({}, {'logps': torch.tensor([1.0], requires_grad=True)}) - assert torch.isnan(r['loss']) - - def test_consecutive_errors_all_caught(self): - w = safe_loss(ExplodingLoss()) - for _ in range(10): - r = w({}, {'logps': torch.randn(3, requires_grad=True)}) - assert r['loss'].item() == 0.0 - - def test_error_then_normal(self): - call_count = [0] - - class FlakeyLoss(Loss): - def __call__(self, inputs, outputs, **kwargs): - call_count[0] += 1 - if call_count[0] == 1: - raise RuntimeError('first call fails') - return LossOutput(loss=outputs['logps'].sum(), num_tokens=1) - - w = safe_loss(FlakeyLoss()) - out1 = {'logps': torch.tensor([1.0, 2.0], requires_grad=True)} - r1 = w({}, out1) - assert r1['loss'].item() == 0.0 - - out2 = {'logps': torch.tensor([4.0, 5.0], requires_grad=True)} - r2 = w({}, out2) - assert r2['loss'].item() == 9.0 - - def test_keyboard_interrupt_propagates(self): - """KeyboardInterrupt is BaseException, NOT caught.""" - class KBLoss(Loss): - def __call__(self, *a, **kw): - raise KeyboardInterrupt() - - w = safe_loss(KBLoss()) - with pytest.raises(KeyboardInterrupt): - w({}, {'logps': torch.tensor([1.0], requires_grad=True)}) - - def test_system_exit_propagates(self): - class ExitLoss(Loss): - def __call__(self, *a, **kw): - raise SystemExit(1) - - w = safe_loss(ExitLoss()) - with pytest.raises(SystemExit): - w({}, {'logps': torch.tensor([1.0], requires_grad=True)}) - - -# ═══════════════════════════════════════════════════════════════════════════ -# 9. Backward Compatibility (dev mode transparency) -# ═══════════════════════════════════════════════════════════════════════════ - - -class TestBackwardCompat: - - def test_safe_loss_transparent(self, _dev_mode): - loss = ExplodingLoss() - wrapped = safe_loss(loss) - assert wrapped is not loss # always wrapped, but transparent in dev mode - with pytest.raises(RuntimeError, match='Simulated'): - wrapped({}, {'logps': torch.tensor([1.0])}) - - def test_nccl_safe_transparent(self, _dev_mode): - @nccl_safe - def method(self, *, inputs, **kwargs): - raise ValueError('should propagate') - - with pytest.raises(ValueError, match='should propagate'): - method(MagicMock(), inputs=[]) - - def test_og_transparent(self, _dev_mode): - og = BaseOptimizerGroup() - loss = ExplodingLoss() - og.loss_instance = loss - assert og.loss_instance is not loss # always wrapped - assert og.loss_instance._nccl_safe_wrapped is True - with pytest.raises(RuntimeError, match='Simulated'): - og.loss_instance({}, {'logps': torch.tensor([1.0])}) - - -# ═════════════════════════════════════════════════════════════════════════ -# 10. Unit Tests: @nccl_safe_megatron decorator -# ═════════════════════════════════════════════════════════════════════════ - - -class TestNcclSafeMegatron: - - # ── Dev mode: transparent ── - - def test_transparent_in_dev_mode(self, _dev_mode): - @nccl_safe_megatron - def method(self, *, inputs, **kwargs): - raise ValueError('should propagate') - - with pytest.raises(ValueError, match='should propagate'): - method(MagicMock(), inputs=[]) - - def test_transparent_tinker_in_dev_mode(self, _dev_mode): - @nccl_safe_megatron(tinker=True) - def method(self, *, inputs, **kwargs): - raise RuntimeError('dev error') - - with pytest.raises(RuntimeError, match='dev error'): - method(MagicMock(), inputs=[]) - - def test_transparent_forward_only_in_dev_mode(self, _dev_mode): - @nccl_safe_megatron(forward_only=True) - def method(self, *, inputs, **kwargs): - raise RuntimeError('dev error') - - with pytest.raises(RuntimeError, match='dev error'): - method(MagicMock(), inputs=[]) - - # ── Production mode: catch all exceptions ── - - def test_normal_call_passes(self): - @nccl_safe_megatron - def method(self, *, inputs, **kwargs): - return {'loss': 1.5, 'logps': 'data'} - - result = method(MagicMock(), inputs=[]) - assert result == {'loss': 1.5, 'logps': 'data'} - - def test_exception_returns_fallback_dict(self): - @nccl_safe_megatron - def method(self, *, inputs, **kwargs): - raise RuntimeError('Megatron internal error') - - result = method(MagicMock(), inputs=[]) - assert result == {'loss': 0.0} - - def test_tinker_exception_returns_list(self): - @nccl_safe_megatron(tinker=True) - def method(self, *, inputs, **kwargs): - raise ValueError('data preprocessing failed') - - result = method(MagicMock(), inputs=[]) - assert result == [[], 0.0] - - def test_forward_only_exception_returns_empty_dict(self): - @nccl_safe_megatron(forward_only=True) - def method(self, *, inputs, **kwargs): - raise AssertionError('invalid inputs') - - result = method(MagicMock(), inputs=[]) - assert result == {} - - def test_consecutive_errors_all_caught(self): - call_count = [0] - - @nccl_safe_megatron - def method(self, *, inputs, **kwargs): - call_count[0] += 1 - raise RuntimeError(f'error #{call_count[0]}') - - model = MagicMock() - for _ in range(5): - result = method(model, inputs=[]) - assert result == {'loss': 0.0} - assert call_count[0] == 5 - - def test_error_then_normal(self): - call_count = [0] - - @nccl_safe_megatron - def method(self, *, inputs, **kwargs): - call_count[0] += 1 - if call_count[0] == 1: - raise RuntimeError('first call fails') - return {'loss': 2.0} - - model = MagicMock() - r1 = method(model, inputs=[]) - assert r1 == {'loss': 0.0} - - r2 = method(model, inputs=[]) - assert r2 == {'loss': 2.0} - - def test_keyboard_interrupt_propagates(self): - """KeyboardInterrupt is BaseException, NOT caught.""" - @nccl_safe_megatron - def method(self, *, inputs, **kwargs): - raise KeyboardInterrupt() - - with pytest.raises(KeyboardInterrupt): - method(MagicMock(), inputs=[]) - - def test_system_exit_propagates(self): - @nccl_safe_megatron - def method(self, *, inputs, **kwargs): - raise SystemExit(1) - - with pytest.raises(SystemExit): - method(MagicMock(), inputs=[]) - - # ── DPO-specific scenarios ── - - def test_dpo_forward_only_data_error_caught(self): - """DPO reference forward: data preprocessing error is caught.""" - @nccl_safe_megatron(forward_only=True) - def forward_only(self, *, inputs, **kwargs): - # Simulate template.batch_encode failure - raise AssertionError('Use set_template to add a template') - - result = forward_only(MagicMock(), inputs=[]) - assert result == {} - - def test_dpo_forward_backward_assertion_caught(self): - """DPO training forward_backward: assertion error is caught.""" - @nccl_safe_megatron - def forward_backward(self, *, inputs, **kwargs): - # Simulate batch_size assertion failure - raise AssertionError('Batch size must be even (chosen + rejected pairs)') - - result = forward_backward(MagicMock(), inputs=[]) - assert result == {'loss': 0.0} - - def test_dpo_tinker_ref_logps_mismatch_caught(self): - """DPO via Tinker: ref_logps mismatch is caught.""" - @nccl_safe_megatron(tinker=True) - def tinker_forward_backward(self, *, inputs, **kwargs): - raise ValueError('Cannot align ref_logps shape') - - result = tinker_forward_backward(MagicMock(), inputs=[]) - assert result == [[], 0.0] - - -# ═════════════════════════════════════════════════════════════════════════ -# 11. Multi-adapter concurrent scenarios -# ═════════════════════════════════════════════════════════════════════════ - - -class TestMultiAdapterConcurrency: - """Verify nccl_safe handles multi-adapter scenarios correctly.""" - - def test_different_adapters_independent_state(self): - """Two adapters with independent state: one failing shouldn't corrupt the other.""" - @nccl_safe - def method(self, *, inputs, adapter_name, **kwargs): - og = self.optimizer_group[adapter_name] - og.train_status.outputs = {'logps': torch.tensor([1.0], requires_grad=True)} - if adapter_name == 'adapter_bad': - og.train_status.loss_value = torch.tensor(1.0) - raise RuntimeError('adapter_bad fails after forward') - og.train_status.loss_value = None # backward done - return {'loss': 0.5} - - # Setup model with two adapters - model = MagicMock() - ts_good = TrainStatus() - ts_bad = TrainStatus() - og_good = MagicMock() - og_good.train_status = ts_good - og_bad = MagicMock() - og_bad.train_status = ts_bad - model.optimizer_group = {'adapter_good': og_good, 'adapter_bad': og_bad} - model.backward = MagicMock() - model.model = MagicMock() - model.model.parameters = MagicMock( - return_value=iter([torch.randn(3, requires_grad=True)])) - - # adapter_good should succeed normally - result_good = method(model, inputs=[], adapter_name='adapter_good') - assert result_good == {'loss': 0.5} - - # adapter_bad should be caught and force backward - result_bad = method(model, inputs=[], adapter_name='adapter_bad') - assert result_bad['loss'] == 0.0 - model.backward.assert_called_once() - - def test_sequential_adapter_failures_isolated(self): - """Sequential failures on different adapters don't accumulate state.""" - call_count = [0] - - @nccl_safe - def method(self, *, inputs, adapter_name, **kwargs): - call_count[0] += 1 - og = self.optimizer_group[adapter_name] - og.train_status.outputs = {'logps': torch.tensor([float(call_count[0])], requires_grad=True)} - og.train_status.loss_value = None # backward done - raise RuntimeError(f'post-backward error #{call_count[0]}') - - model = MagicMock() - for name in ['a1', 'a2', 'a3']: - ts = TrainStatus() - og = MagicMock() - og.train_status = ts - model.optimizer_group = {name: og} - model.backward = MagicMock() - result = method(model, inputs=[], adapter_name=name) - assert result['loss'] == 0.0 - # backward should NOT be called (post-backward error) - model.backward.assert_not_called() - - def test_megatron_multi_adapter_all_caught(self): - """nccl_safe_megatron catches errors for any adapter.""" - @nccl_safe_megatron - def method(self, *, inputs, adapter_name, **kwargs): - if adapter_name == 'bad': - raise RuntimeError('bad adapter data') - return {'loss': 1.0} - - model = MagicMock() - assert method(model, inputs=[], adapter_name='good') == {'loss': 1.0} - assert method(model, inputs=[], adapter_name='bad') == {'loss': 0.0} - # After error, good adapter still works - assert method(model, inputs=[], adapter_name='good') == {'loss': 1.0} - - -# ═════════════════════════════════════════════════════════════════════════ -# 12. Megatron communication timeout simulation -# ═════════════════════════════════════════════════════════════════════════ - - -class TestMegatronTimeoutSimulation: - """Simulate Megatron internal communication failures.""" - - def test_nccl_timeout_exception_caught(self): - """NCCL timeout RuntimeError inside Megatron is caught.""" - @nccl_safe_megatron - def forward_backward(self, *, inputs, **kwargs): - # Simulate NCCL timeout - raise RuntimeError( - 'Watchdog caught collective operation timeout: ' - 'WorkNCCL(SeqNum=42, OpType=ALLREDUCE) ran for 300000 milliseconds') - - result = forward_backward(MagicMock(), inputs=[]) - assert result == {'loss': 0.0} - - def test_nccl_timeout_in_forward_only(self): - """NCCL timeout during forward_only (reference model) is caught.""" - @nccl_safe_megatron(forward_only=True) - def forward_only(self, *, inputs, **kwargs): - raise RuntimeError( - 'NCCL communicator was aborted on rank 1. Original reason: ' - 'ProcessGroupNCCL abort') - - result = forward_only(MagicMock(), inputs=[]) - assert result == {} - - def test_cuda_oom_in_megatron_caught(self): - """CUDA OOM during Megatron forward is caught.""" - @nccl_safe_megatron - def forward_backward(self, *, inputs, **kwargs): - raise RuntimeError('CUDA out of memory. Tried to allocate 2.00 GiB') - - result = forward_backward(MagicMock(), inputs=[]) - assert result == {'loss': 0.0} - - def test_recovery_after_timeout(self): - """System recovers after a simulated timeout.""" - call_count = [0] - - @nccl_safe_megatron - def forward_backward(self, *, inputs, **kwargs): - call_count[0] += 1 - if call_count[0] <= 2: - raise RuntimeError('NCCL timeout') - return {'loss': 1.5} - - model = MagicMock() - # First two calls timeout - assert forward_backward(model, inputs=[]) == {'loss': 0.0} - assert forward_backward(model, inputs=[]) == {'loss': 0.0} - # Third call succeeds - assert forward_backward(model, inputs=[]) == {'loss': 1.5} - - def test_megatron_transparent_in_fail_fast(self, _dev_mode): - """In dev mode (TWINKLE_FAIL_FAST=1), exceptions propagate normally.""" - @nccl_safe_megatron - def forward_backward(self, *, inputs, **kwargs): - raise RuntimeError('would cause NCCL hang in production') - - # In dev mode, exception propagates - with pytest.raises(RuntimeError, match='would cause NCCL hang'): - forward_backward(MagicMock(), inputs=[]) - - def test_base_exception_still_propagates_in_dev_mode(self, _dev_mode): - """BaseException (KeyboardInterrupt, SystemExit) always propagates.""" - @nccl_safe_megatron - def method(self, *, inputs, **kwargs): - raise KeyboardInterrupt() - - with pytest.raises(KeyboardInterrupt): - method(MagicMock(), inputs=[]) From 99b596c2e205e8610098c2fe46544e7520b500a0 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Fri, 26 Jun 2026 12:58:24 +0800 Subject: [PATCH 12/16] fix nccl hang --- src/twinkle/model/megatron/megatron.py | 103 ++++++++++-------- .../model/megatron/strategy/megatron.py | 4 +- .../integration/test_nccl_safe_tinker_e2e.py | 4 +- .../integration/test_nccl_safe_twinkle_e2e.py | 9 +- 4 files changed, 67 insertions(+), 53 deletions(-) diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index ab5f936da..76e0c8dba 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -36,6 +36,7 @@ from twinkle.processor import InputProcessor from twinkle.template import Template from twinkle.utils import construct_class, get_logger, selective_log_softmax +from twinkle.utils.nccl_safe import _is_fail_fast from ._mindspeed_runtime import ensure_mindspeed_adaptor_patched from .strategy import MegatronStrategy @@ -385,14 +386,11 @@ def post_loss_function(output_tensor, inputs, logps, unpacked_logits=None, entro losses = result['loss'] counts = result['num_tokens'] if not counts: - # Later will gather this value, so it becomes: - # 1. SUM loss: gather_sum(local_num_tokens) = global_num_tokens - # 2. PER TOKEN MEAN loss: gather_sum(1 * gradient_accumulation_steps ) - # = gradient_accumulation_steps * world_size - # Then, grad will divided by this value: - # 1. SUM loss: (global_sum_grad) / (global_num_tokens) = global_sum_grad/global_num_tokens - # 2. PER TOKEN MEAN loss: (gather_sum(per_token_grad * gradient_accumulation_steps)) - # / (gradient_accumulation_steps * world_size ) = avg_per_token_grad + # safe_loss returned zero loss (num_tokens=0): use output_tensor + # to rebuild graph connectivity so backward triggers ALL gradient + # buckets, preventing DP AllReduce asymmetry in PP mode. + if output_tensor.requires_grad: + losses = (output_tensor.flatten()[:1] * 0).sum() counts = torch.tensor(1, device=losses.device) return self.strategy.reduce_loss(losses, counts, output_tensor, logps) @@ -415,44 +413,59 @@ def forward_step_func(data_iterator, model): embeddings = None _loss_instance = loss_instance is_last_pp = mpu.is_pipeline_last_stage(False, unwrapped_model.vp_stage) - if task == 'embedding': - # MegatronEmbeddingPatch already pooled output to [n_seqs, hidden] on last PP stage. - if is_last_pp: - embeddings = output_tensor - elif labels is not None and is_last_pp: - _loss_require_logps = getattr(_loss_instance, 'require_logps', True) - _loss_require_entropy = (hasattr(_loss_instance, 'require_entropy') and _loss_instance.require_entropy) - _packed = batch.get('packed_seq_params') - cu_seqlens_q = getattr(_packed, 'cu_seqlens_q', None) if _packed is not None else None - if _loss_require_logps: - loss_mask = (labels != -100).bool() - masked_labels = labels.clone() - masked_labels[~loss_mask] = 0 - output_tensor.div_(temperature) - if _loss_require_entropy: - logps, entropies = selective_log_softmax(output_tensor, masked_labels, return_entropy=True) - else: - logps = selective_log_softmax(output_tensor, masked_labels) - # Reconstruct full-length tensors from CP-split shards - logps = processor.postprocess_tensor_cp(logps, cu_seqlens=cu_seqlens_q) + try: + if task == 'embedding': + # MegatronEmbeddingPatch already pooled output to [n_seqs, hidden] on last PP stage. + if is_last_pp: + embeddings = output_tensor + elif labels is not None and is_last_pp: + _loss_require_logps = getattr(_loss_instance, 'require_logps', True) + _loss_require_entropy = ( + hasattr(_loss_instance, 'require_entropy') and _loss_instance.require_entropy) + _packed = batch.get('packed_seq_params') + cu_seqlens_q = getattr(_packed, 'cu_seqlens_q', None) if _packed is not None else None + if _loss_require_logps: + loss_mask = (labels != -100).bool() + masked_labels = labels.clone() + masked_labels[~loss_mask] = 0 + output_tensor.div_(temperature) + if _loss_require_entropy: + logps, entropies = selective_log_softmax(output_tensor, masked_labels, return_entropy=True) + else: + logps = selective_log_softmax(output_tensor, masked_labels) + # Reconstruct full-length tensors from CP-split shards + logps = processor.postprocess_tensor_cp(logps, cu_seqlens=cu_seqlens_q) + if entropies is not None: + entropies = processor.postprocess_tensor_cp(entropies, cu_seqlens=cu_seqlens_q) + batch['labels'] = processor.postprocess_tensor_cp(labels, cu_seqlens=cu_seqlens_q) + if 'position_ids' in batch: + pos = batch['position_ids'] + if pos.dim() == 3: + pos = pos[0] # [2/3, 1, seq] → [1, seq] + batch['position_ids'] = processor.postprocess_tensor_cp(pos, cu_seqlens=cu_seqlens_q) + # Unpack packed sequences into per-sequence batch format + _outputs = {'logps': logps} if entropies is not None: - entropies = processor.postprocess_tensor_cp(entropies, cu_seqlens=cu_seqlens_q) - batch['labels'] = processor.postprocess_tensor_cp(labels, cu_seqlens=cu_seqlens_q) - if 'position_ids' in batch: - pos = batch['position_ids'] - if pos.dim() == 3: - pos = pos[0] # [2/3, 1, seq] → [1, seq] - batch['position_ids'] = processor.postprocess_tensor_cp(pos, cu_seqlens=cu_seqlens_q) - # Unpack packed sequences into per-sequence batch format - _outputs = {'logps': logps} - if entropies is not None: - _outputs['entropies'] = entropies - if hasattr(_loss_instance, 'require_logits') and _loss_instance.require_logits: - _outputs['logits'] = output_tensor - batch, _outputs = processor.unpack_packed_sequences(batch, _outputs) - logps = _outputs['logps'] - entropies = _outputs.get('entropies', None) - unpacked_logits = _outputs.get('logits', None) + _outputs['entropies'] = entropies + if hasattr(_loss_instance, 'require_logits') and _loss_instance.require_logits: + _outputs['logits'] = output_tensor + batch, _outputs = processor.unpack_packed_sequences(batch, _outputs) + logps = _outputs['logps'] + entropies = _outputs.get('entropies', None) + unpacked_logits = _outputs.get('logits', None) + except Exception as e: + # Data processing error (e.g. unpack_packed_sequences dimension mismatch). + # Must catch here inside the scheduler to prevent exception escaping + # and breaking PP P2P communication → NCCL hang. + if _is_fail_fast(): + raise + logger.warning('[nccl_safe] forward_step_func data processing error: ' + '%s: %s', + type(e).__name__, e) + logps = None + unpacked_logits = None + entropies = None + embeddings = None return output_tensor, partial( post_loss_function, inputs=batch, diff --git a/src/twinkle/model/megatron/strategy/megatron.py b/src/twinkle/model/megatron/strategy/megatron.py index 819014eb8..fae4e4241 100644 --- a/src/twinkle/model/megatron/strategy/megatron.py +++ b/src/twinkle/model/megatron/strategy/megatron.py @@ -257,8 +257,8 @@ def reduce_loss(self, local_loss, local_count, logits, logps): grad_count = (count // cp_size).clamp(min=1) if cp_size > 1 else count return local_loss, grad_count, { 'loss': local_loss.detach(), - 'logits': logits.detach(), - 'logps': logps.detach(), + 'logits': logits.detach() if logits is not None else None, + 'logps': logps.detach() if logps is not None else None, 'num_tokens': count } diff --git a/tests/server/integration/test_nccl_safe_tinker_e2e.py b/tests/server/integration/test_nccl_safe_tinker_e2e.py index eec926dd4..aa94f99c4 100644 --- a/tests/server/integration/test_nccl_safe_tinker_e2e.py +++ b/tests/server/integration/test_nccl_safe_tinker_e2e.py @@ -184,7 +184,7 @@ def test_4_no_advantages(tc): def test_5_consecutive_bad(tc): for i in range(5): - datums = [make_datum(seq_len=64, completion_len=32, bad_logprobs_len=3+i) for _ in range(2)] + datums = [make_datum(seq_len=64, completion_len=32, bad_logprobs_len=3+i) for _ in range(4)] _, _, elapsed = run_forward_backward(tc, datums, f'TEST-5-{i+1}') if elapsed >= TIMEOUT: return False @@ -345,7 +345,7 @@ def test_17_single_datum(tc): return True def test_18_save_after_error(tc): - bad = [make_datum(seq_len=64, completion_len=32, bad_logprobs_len=2) for _ in range(2)] + bad = [make_datum(seq_len=64, completion_len=32, bad_logprobs_len=2) for _ in range(4)] _, _, elapsed = run_forward_backward(tc, bad, 'TEST-18-ERR') if elapsed >= TIMEOUT: return False diff --git a/tests/server/integration/test_nccl_safe_twinkle_e2e.py b/tests/server/integration/test_nccl_safe_twinkle_e2e.py index 63d12b231..c9cce48a6 100644 --- a/tests/server/integration/test_nccl_safe_twinkle_e2e.py +++ b/tests/server/integration/test_nccl_safe_twinkle_e2e.py @@ -103,6 +103,7 @@ def make_input_features( 'input_ids': input_ids, 'labels': labels, 'attention_mask': [1] * seq_len, + 'position_ids': list(range(seq_len)), }) if bad_old_logps_len is not None: @@ -212,7 +213,7 @@ def test_6_all_labels_masked(m): def test_7_consecutive_bad(m): for i in range(5): - inputs, old_logps, adv = make_input_features(batch_size=2, bad_old_logps_len=i+1) + inputs, old_logps, adv = make_input_features(batch_size=4, bad_old_logps_len=i+1) _, _, elapsed = run_forward_backward(m, inputs, old_logps, adv, f'TEST-7-{i+1}') if elapsed >= TIMEOUT: return False @@ -241,15 +242,15 @@ def test_9_final_health(m): return True def test_10_gradient_accumulation_error(m): - inputs, lp, adv = make_input_features(batch_size=2) + inputs, lp, adv = make_input_features(batch_size=4) ok, _, elapsed = run_forward_backward(m, inputs, lp, adv, 'TEST-10-GA1') if not ok or elapsed >= TIMEOUT: return False - bad_in, bad_lp, bad_adv = make_input_features(batch_size=2, bad_old_logps_len=3) + bad_in, bad_lp, bad_adv = make_input_features(batch_size=4, bad_old_logps_len=3) _, _, elapsed = run_forward_backward(m, bad_in, bad_lp, bad_adv, 'TEST-10-GA2-BAD') if elapsed >= TIMEOUT: return False - inputs, lp, adv = make_input_features(batch_size=2) + inputs, lp, adv = make_input_features(batch_size=4) ok, _, elapsed = run_forward_backward(m, inputs, lp, adv, 'TEST-10-GA3') if not ok or elapsed >= TIMEOUT: return False From 61a7e2b8fc72c557983f8e3fe6f6d8418189f923 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Fri, 26 Jun 2026 14:25:25 +0800 Subject: [PATCH 13/16] update test --- ...rver_config_4b_e2e_megatron_failfast1.yaml | 126 +----- tests/server/integration/e2e_helpers.py | 271 +++++++++++ tests/server/integration/test_dpo_e2e.py | 307 +++++++++++++ .../test_dpo_nccl_safe_megatron.py | 184 -------- tests/server/integration/test_dpo_pp_e2e.py | 191 -------- .../integration/test_dpo_tinker_pp_e2e.py | 149 ------ tests/server/integration/test_grpo_e2e.py | 428 ++++++++++++++++++ tests/server/integration/test_sft_e2e.py | 200 ++++++++ 8 files changed, 1207 insertions(+), 649 deletions(-) create mode 100644 tests/server/integration/e2e_helpers.py create mode 100644 tests/server/integration/test_dpo_e2e.py delete mode 100644 tests/server/integration/test_dpo_nccl_safe_megatron.py delete mode 100644 tests/server/integration/test_dpo_pp_e2e.py delete mode 100644 tests/server/integration/test_dpo_tinker_pp_e2e.py create mode 100644 tests/server/integration/test_grpo_e2e.py create mode 100644 tests/server/integration/test_sft_e2e.py diff --git a/tests/server/config/server_config_4b_e2e_megatron_failfast1.yaml b/tests/server/config/server_config_4b_e2e_megatron_failfast1.yaml index 848f4755d..64af5194d 100644 --- a/tests/server/config/server_config_4b_e2e_megatron_failfast1.yaml +++ b/tests/server/config/server_config_4b_e2e_megatron_failfast1.yaml @@ -50,6 +50,7 @@ applications: tps_limit: 100000 adapter_config: adapter_timeout: 30 + adapter_max_lifetime: 36000 deployments: - name: ModelManagement autoscaling_config: @@ -98,131 +99,6 @@ applications: TWINKLE_TRUST_REMOTE_CODE: "1" TWINKLE_FAIL_FAST: "1" - - name: processor - route_prefix: /api/v1/processor - import_path: processor - args: - ncpu_proc_per_node: 2 - device_group: - name: model - ranks: 2 - device_type: CPU - device_mesh: - device_type: CPU - dp_size: 2 - deployments: - - name: ProcessorManagement - autoscaling_config: - min_replicas: 1 - max_replicas: 1 - target_ongoing_requests: 128 - ray_actor_options: - num_cpus: 0.1 - runtime_env: - env_vars: - TWINKLE_FAIL_FAST: "1" -# Twinkle Server Configuration - E2E Test (4B model, Megatron backend) - -proxy_location: EveryNode - -http_options: - host: 0.0.0.0 - port: 9000 - -applications: - - - name: server - route_prefix: /api/v1 - import_path: server - args: - server_config: - per_token_model_limit: 20 - supported_models: - - Qwen/Qwen3.5-4B - deployments: - - name: TinkerCompatServer - max_ongoing_requests: 50 - autoscaling_config: - min_replicas: 1 - max_replicas: 1 - target_ongoing_requests: 128 - ray_actor_options: - num_cpus: 0.1 - runtime_env: - env_vars: - TWINKLE_FAIL_FAST: "1" - - - name: models-Qwen3.5-4B - route_prefix: /api/v1/model/Qwen/Qwen3.5-4B - import_path: model - args: - backend: megatron - model_id: "ms://Qwen/Qwen3.5-4B" - max_length: 10240 - nproc_per_node: 4 - device_group: - name: model - ranks: 4 - device_type: cuda - device_mesh: - device_type: cuda - dp_size: 2 - pp_size: 2 - queue_config: - rps_limit: 100 - tps_limit: 100000 - adapter_config: - adapter_timeout: 30 - adapter_max_lifetime: 36000 - deployments: - - name: ModelManagement - autoscaling_config: - min_replicas: 1 - max_replicas: 1 - target_ongoing_requests: 16 - ray_actor_options: - num_cpus: 0.1 - runtime_env: - env_vars: - TWINKLE_TRUST_REMOTE_CODE: "1" - TWINKLE_FAIL_FAST: "1" - - # - name: sampler-Qwen3.5-4B - # route_prefix: /api/v1/sampler/Qwen/Qwen3.5-4B - # import_path: sampler - # args: - # model_id: "ms://Qwen/Qwen3.5-4B" - # nproc_per_node: 1 - # sampler_type: vllm - # engine_args: - # max_model_len: 4096 - # gpu_memory_utilization: 0.5 - # enable_lora: true - # max_loras: 5 - # logprobs_mode: processed_logprobs - # device_group: - # name: sampler - # ranks: 1 - # device_type: cuda - # device_mesh: - # device_type: cuda - # dp_size: 1 - # queue_config: - # rps_limit: 100 - # tps_limit: 100000 - # deployments: - # - name: SamplerManagement - # autoscaling_config: - # min_replicas: 1 - # max_replicas: 1 - # target_ongoing_requests: 16 - # ray_actor_options: - # num_cpus: 0.1 - # runtime_env: - # env_vars: - # TWINKLE_TRUST_REMOTE_CODE: "1" - # TWINKLE_FAIL_FAST: "1" - - name: processor route_prefix: /api/v1/processor import_path: processor diff --git a/tests/server/integration/e2e_helpers.py b/tests/server/integration/e2e_helpers.py new file mode 100644 index 000000000..aaaa87495 --- /dev/null +++ b/tests/server/integration/e2e_helpers.py @@ -0,0 +1,271 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Shared utilities for E2E integration tests. + +Provides reusable helpers for server health check, client initialization, +dataset preparation, and test assertions across all 12 test combinations +(2 backends x 2 clients x 3 tasks). +""" +from __future__ import annotations + +import math +import os +import time +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np + +# ═══════════════════════════════════════════════════════════════════════════ +# Constants +# ═══════════════════════════════════════════════════════════════════════════ + +BASE_MODEL = 'Qwen/Qwen3.5-4B' +MODEL_ID = f'ms://{BASE_MODEL}' +BASE_URL = os.environ.get('TWINKLE_SERVER_URL', 'http://localhost:9000') +API_KEY = 'EMPTY_API_KEY' +TIMEOUT = 120 # seconds per operation before declaring hang +GRADIENT_ACCUMULATION_STEPS = 2 # Megatron requires GA >= 2 + + +def get_backend() -> str: + """Read backend type from environment variable.""" + backend = os.environ.get('TWINKLE_TEST_BACKEND', 'transformers').lower() + assert backend in ('transformers', 'megatron'), ( + f'Invalid TWINKLE_TEST_BACKEND={backend!r}, must be transformers or megatron') + return backend + + +# ═══════════════════════════════════════════════════════════════════════════ +# Server Health Check +# ═══════════════════════════════════════════════════════════════════════════ + +def wait_for_server(url: str = BASE_URL, timeout: int = 300) -> None: + """Wait for Twinkle server to become ready using Python requests.""" + import requests + + start = time.time() + while time.time() - start < timeout: + try: + resp = requests.get(f'{url}/-/routes', timeout=5) + if resp.status_code == 200: + elapsed = int(time.time() - start) + log(f'Server ready (waited {elapsed}s)') + return + except (requests.ConnectionError, requests.Timeout): + pass + time.sleep(5) + raise TimeoutError(f'Server not ready after {timeout}s at {url}') + + +def log(msg: str) -> None: + """Timestamped log output.""" + ts = time.strftime('%H:%M:%S') + print(f'[{ts}][E2E] {msg}', flush=True) + + +# ═══════════════════════════════════════════════════════════════════════════ +# Dataset Factories +# ═══════════════════════════════════════════════════════════════════════════ + +def create_sft_dataset(data_slice=range(100)): + """Create SelfCognition SFT dataset (small slice for speed).""" + from twinkle.dataloader import DataLoader + from twinkle.dataset import Dataset, DatasetMeta + + dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=data_slice)) + dataset.set_template('Qwen3_5Template', model_id=MODEL_ID, max_length=256) + dataset.map('SelfCognitionProcessor', init_args={'model_name': 'twinkle模型', 'model_author': 'ModelScope社区'}) + dataset.encode(batched=True, load_from_cache_file=False) + return dataset + + +def create_dpo_dataset(data_slice=range(50)): + """Create EmojiDPO dataset with positive/negative format.""" + from twinkle.dataset import Dataset, DatasetMeta + from twinkle.preprocessor import EmojiDPOProcessor + + dataset = Dataset(DatasetMeta('ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji', data_slice=data_slice)) + dataset.set_template('Qwen3_5Template', model_id=MODEL_ID, max_length=1024) + dataset.map(EmojiDPOProcessor, init_args={'system': 'You are a helpful assistant.'}) + dataset.encode() + return dataset + + +def create_grpo_dataset(data_slice=range(50)): + """Create GSM8K dataset for GRPO training.""" + from twinkle.dataset import Dataset, DatasetMeta + from twinkle.preprocessor.llm import GSM8KProcessor + + system_prompt = ('You are a helpful math assistant. Solve the problem with minimal but correct reasoning ' + 'and put your final answer within \\boxed{}.') + dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train', data_slice=data_slice)) + dataset.set_template('Qwen3_5Template', model_id=MODEL_ID, max_length=2048, enable_thinking=False) + dataset.map(GSM8KProcessor(system=system_prompt)) + dataset.encode(add_generation_prompt=True) + return dataset + + +# ═══════════════════════════════════════════════════════════════════════════ +# Twinkle Client Factories +# ═══════════════════════════════════════════════════════════════════════════ + +def init_twinkle_client_session(): + """Initialize the Twinkle client session.""" + from twinkle import init_twinkle_client + return init_twinkle_client(base_url=BASE_URL, api_key=API_KEY) + + +def create_twinkle_sft_model(): + """Create Twinkle model configured for SFT (CrossEntropyLoss).""" + from peft import LoraConfig + from twinkle_client.model import MultiLoraTransformersModel + + model = MultiLoraTransformersModel(model_id=MODEL_ID) + model.add_adapter_to_model( + 'default', + LoraConfig(target_modules='all-linear'), + gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS, + ) + model.set_template('Qwen3_5Template') + model.set_processor('InputProcessor', padding_side='right') + model.set_loss('CrossEntropyLoss') + model.set_optimizer('Adam', lr=1e-4) + return model + + +def create_twinkle_dpo_model(): + """Create Twinkle model configured for DPO training.""" + from peft import LoraConfig + from twinkle_client.model import MultiLoraTransformersModel + + model = MultiLoraTransformersModel(model_id=MODEL_ID) + model.add_adapter_to_model( + 'default', + LoraConfig(target_modules='all-linear', r=8, lora_alpha=32, lora_dropout=0.05), + gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS, + ) + model.set_template('Qwen3_5Template') + model.set_processor('InputProcessor', padding_side='right') + model.set_loss('DPOLoss', beta=0.1, loss_type='sigmoid', reference_free=False, sft_weight=1.0) + model.add_metric('DPOMetric', beta=0.1) + model.set_optimizer('Adam', lr=1e-4) + return model + + +def create_twinkle_grpo_model(): + """Create Twinkle model configured for GRPO training.""" + from peft import LoraConfig + from twinkle_client.model import MultiLoraTransformersModel + + model = MultiLoraTransformersModel(model_id=MODEL_ID) + model.add_adapter_to_model( + 'default', + LoraConfig(target_modules='all-linear', r=8, lora_alpha=32, lora_dropout=0.05), + gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS, + ) + model.set_loss('GRPOLoss', epsilon=0.2, beta=0.0) + model.set_optimizer('Adam', lr=2e-5) + model.set_processor('InputProcessor') + model.set_template('Qwen3_5Template', model_id=MODEL_ID) + return model + + +def create_twinkle_sampler(): + """Create Twinkle vLLM sampler.""" + from twinkle_client.sampler import vLLMSampler + + sampler = vLLMSampler(model_id=MODEL_ID) + sampler.set_template('Qwen3_5Template', model_id=MODEL_ID) + return sampler + + +# ═══════════════════════════════════════════════════════════════════════════ +# Tinker Client Factories +# ═══════════════════════════════════════════════════════════════════════════ + +def init_tinker_client_session(): + """Initialize the Tinker client session and return ServiceClient.""" + from twinkle import init_tinker_client + init_tinker_client() + from tinker import ServiceClient + return ServiceClient(base_url=BASE_URL, api_key=API_KEY) + + +def create_tinker_training_client(rank: int = 8): + """Create Tinker LoRA training client.""" + service_client = init_tinker_client_session() + training_client = service_client.create_lora_training_client( + base_model=BASE_MODEL, + rank=rank, + ) + return training_client + + +# ═══════════════════════════════════════════════════════════════════════════ +# Data Processing Utilities +# ═══════════════════════════════════════════════════════════════════════════ + +def convert_tensors(batch: List[Dict[str, Any]]) -> None: + """Convert numpy/torch tensors to Python lists in-place for serialization.""" + import torch + + for row in batch: + for key in list(row.keys()): + val = row[key] + if isinstance(val, np.ndarray): + row[key] = val.tolist() + elif isinstance(val, torch.Tensor): + row[key] = val.cpu().numpy().tolist() + elif isinstance(val, dict): + for k2, v2 in val.items(): + if isinstance(v2, np.ndarray): + val[k2] = v2.tolist() + elif isinstance(v2, torch.Tensor): + val[k2] = v2.cpu().numpy().tolist() + + +def prepare_dpo_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Reorganize batch into DP-safe interleaved format [pos_1, neg_1, pos_2, neg_2, ...].""" + result = [] + for row in batch: + base_fields = {k: v for k, v in row.items() if k not in ('positive', 'negative')} + pos_sample = {**base_fields, **row['positive']} + neg_sample = {**base_fields, **row['negative']} + result.append(pos_sample) + result.append(neg_sample) + return result + + +# ═══════════════════════════════════════════════════════════════════════════ +# Assertions / Pass Criteria +# ═══════════════════════════════════════════════════════════════════════════ + +def assert_no_timeout(elapsed: float, label: str, timeout: float = TIMEOUT) -> None: + """Assert operation completed within timeout (no NCCL hang).""" + assert elapsed < timeout, ( + f'[{label}] TIMEOUT ({elapsed:.1f}s > {timeout}s) — possible NCCL hang!') + + +def assert_loss_decreases(losses: List[float], label: str) -> None: + """Assert training loss shows a downward trend. + + Verifies that the average of the last 3 loss values is lower than + the average of the first 3 loss values. + """ + assert len(losses) >= 4, f'[{label}] Need at least 4 loss values, got {len(losses)}' + first_avg = sum(losses[:3]) / 3 + last_avg = sum(losses[-3:]) / 3 + assert last_avg < first_avg, ( + f'[{label}] Loss did NOT decrease: first_3_avg={first_avg:.4f} >= last_3_avg={last_avg:.4f}') + log(f'[{label}] Loss decreased: {first_avg:.4f} -> {last_avg:.4f}') + + +def assert_metrics_valid(metrics: Any, label: str) -> None: + """Assert metrics contain finite (non-NaN, non-Inf) values.""" + if metrics is None: + return + if isinstance(metrics, dict): + for key, val in metrics.items(): + if isinstance(val, (int, float)): + assert math.isfinite(val), ( + f'[{label}] Metric {key}={val} is not finite!') + log(f'[{label}] Metrics valid: {metrics}') diff --git a/tests/server/integration/test_dpo_e2e.py b/tests/server/integration/test_dpo_e2e.py new file mode 100644 index 000000000..d66250a52 --- /dev/null +++ b/tests/server/integration/test_dpo_e2e.py @@ -0,0 +1,307 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""DPO (Direct Preference Optimization) E2E integration tests. + +Tests DPO training across all 4 combinations: + - Twinkle client x (transformers | megatron) + - Tinker client x (transformers | megatron) + +Backend selection via env var TWINKLE_TEST_BACKEND (default: transformers). + +## How to run + + # Start server + python tests/server/start_e2e_server.py --config tests/server/config/server_config_4b_e2e.yaml + + # Run DPO tests + TWINKLE_TEST_GPU_E2E=1 TWINKLE_TEST_BACKEND=transformers pytest tests/server/integration/test_dpo_e2e.py -v +""" +from __future__ import annotations + +import os +import sys +import time +from typing import Any, Dict, List + +# Ensure project root is in sys.path for both pytest and direct execution +_PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..')) +if _PROJECT_ROOT not in sys.path: + sys.path.insert(0, _PROJECT_ROOT) + +import numpy as np +import pytest + +pytestmark = pytest.mark.skipif( + os.environ.get('TWINKLE_TEST_GPU_E2E', '0') != '1', + reason='Set TWINKLE_TEST_GPU_E2E=1 to run real GPU E2E tests (requires running server)', +) + +from tests.server.integration.e2e_helpers import ( + BASE_MODEL, + BASE_URL, + GRADIENT_ACCUMULATION_STEPS, + MODEL_ID, + TIMEOUT, + assert_metrics_valid, + assert_no_timeout, + convert_tensors, + create_dpo_dataset, + create_tinker_training_client, + create_twinkle_dpo_model, + get_backend, + init_twinkle_client_session, + log, + prepare_dpo_batch, + wait_for_server, +) + +# ── Configuration ── +DPO_TRAIN_STEPS = 8 +DPO_BETA = 0.1 +DPO_SFT_WEIGHT = 1.0 + + +# ═══════════════════════════════════════════════════════════════════════════ +# Test: DPO via Twinkle client +# ═══════════════════════════════════════════════════════════════════════════ + +def test_dpo_twinkle(): + """DPO training via Twinkle client (MultiLoraTransformersModel). + + Flow per step: + 1. forward_only (reference, disable_lora=True) -> ref_outputs + 2. forward_backward (DPO loss with ref_outputs) + 3. clip_grad_and_step + + Pass criteria: + - All steps complete without timeout (< 120s each) + - DPO metrics are valid (non-NaN/Inf) + - No NCCL hang + """ + import torch + from twinkle.dataloader import DataLoader + + backend = get_backend() + log(f'=== test_dpo_twinkle [backend={backend}] ===') + + wait_for_server() + init_twinkle_client_session() + + # Setup + dataset = create_dpo_dataset() + dataloader = DataLoader(dataset=dataset, batch_size=4) + model = create_twinkle_dpo_model() + + log(f'Dataset: {len(dataset)} samples, DPO training {DPO_TRAIN_STEPS} steps') + + # Training loop + losses = [] + reward_margins = [] + step = 0 + for batch in dataloader: + if step >= DPO_TRAIN_STEPS: + break + + # Convert tensors for serialization + convert_tensors(batch) + + # Interleave positive/negative pairs + dpo_batch = prepare_dpo_batch(batch) + + # Step 1: Reference forward (base model, no LoRA) + log(f'[step {step + 1}] forward_only (reference)...') + t0 = time.time() + ref_outputs = model.forward_only(inputs=dpo_batch, disable_lora=True) + elapsed_fo = time.time() - t0 + assert_no_timeout(elapsed_fo, f'dpo_twinkle forward_only step {step}') + log(f'[step {step + 1}] forward_only OK ({elapsed_fo:.1f}s)') + + # Step 2: DPO forward_backward with ref_outputs + log(f'[step {step + 1}] forward_backward (DPO)...') + t1 = time.time() + model.forward_backward(inputs=dpo_batch, ref_outputs=ref_outputs.result) + elapsed_fb = time.time() - t1 + assert_no_timeout(elapsed_fb, f'dpo_twinkle forward_backward step {step}') + log(f'[step {step + 1}] forward_backward OK ({elapsed_fb:.1f}s)') + + # Step 3: Optimizer step + model.clip_grad_and_step() + + # Log metrics every GA steps + if (step + 1) % GRADIENT_ACCUMULATION_STEPS == 0: + metrics = model.calculate_metric(is_training=True) + if hasattr(metrics, 'result'): + assert_metrics_valid(metrics.result, f'dpo_twinkle step {step}') + # Track DPO loss and rewards + result = metrics.result + if isinstance(result, dict): + loss_val = result.get('loss') + if loss_val is not None: + losses.append(float(loss_val)) + reward_margin = result.get('rewards/margins') + if reward_margin is not None: + reward_margins.append(float(reward_margin)) + + step += 1 + + assert step == DPO_TRAIN_STEPS, f'Expected {DPO_TRAIN_STEPS} steps, completed {step}' + + # Verify DPO loss decreases + backend = get_backend() + if len(losses) >= 3 and not all(l == 0.0 for l in losses): + log(f'DPO losses: {["{:.4f}".format(l) for l in losses]}') + assert losses[-1] < losses[0], ( + f'[dpo_twinkle] DPO loss did NOT decrease: first={losses[0]:.4f} last={losses[-1]:.4f}') + log(f'[dpo_twinkle] Loss decreased: {losses[0]:.4f} -> {losses[-1]:.4f}') + elif backend == 'megatron': + log('[dpo_twinkle] Megatron: loss reports 0 (known behavior), verifying training completed OK') + + # Verify reward margins increase (DPO learns to prefer chosen) + if len(reward_margins) >= 3 and not all(abs(r) < 1e-6 for r in reward_margins): + log(f'Reward margins: {["{:.4f}".format(r) for r in reward_margins]}') + assert reward_margins[-1] > reward_margins[0], ( + f'[dpo_twinkle] Reward margins did NOT increase: first={reward_margins[0]:.4f} last={reward_margins[-1]:.4f}') + log(f'[dpo_twinkle] Reward margins increased: {reward_margins[0]:.4f} -> {reward_margins[-1]:.4f}') + + log(f'test_dpo_twinkle PASSED (backend={backend})') + + +# ═══════════════════════════════════════════════════════════════════════════ +# Test: DPO via Tinker client +# ═══════════════════════════════════════════════════════════════════════════ + +def test_dpo_tinker(): + """DPO training via Tinker client (ServiceClient + forward/forward_backward). + + Flow per step: + 1. forward (cross_entropy, disable_lora=True) -> ref logps + 2. Attach ref_logps to datums + 3. forward_backward (importance_sampling) -> DPO loss + 4. optim_step + + Pass criteria: + - All steps complete without timeout (< 120s each) + - Metrics are valid + - No NCCL hang + """ + from tinker import types + from twinkle.dataloader import DataLoader + from twinkle.server.common import input_feature_to_datum + + backend = get_backend() + log(f'=== test_dpo_tinker [backend={backend}] ===') + + wait_for_server() + + # Setup + dataset = create_dpo_dataset() + dataloader = DataLoader(dataset=dataset, batch_size=4) + training_client = create_tinker_training_client(rank=8) + + log(f'Dataset: {len(dataset)} samples, DPO training {DPO_TRAIN_STEPS} steps') + + # Training loop + losses = [] + reward_margins = [] + step = 0 + for batch in dataloader: + if step >= DPO_TRAIN_STEPS: + break + + # Convert tensors + convert_tensors(batch) + + # Interleave positive/negative pairs + dpo_batch = prepare_dpo_batch(batch) + + # Convert to Tinker Datums + input_datums = [input_feature_to_datum(row) for row in dpo_batch] + + # Step A: Reference forward (base model, disable_lora=True) + log(f'[step {step + 1}] forward (reference, disable_lora)...') + t0 = time.time() + ref_result = training_client.forward( + input_datums, + 'cross_entropy', + loss_fn_config={'disable_lora': True}, + ).result() + elapsed_ref = time.time() - t0 + assert_no_timeout(elapsed_ref, f'dpo_tinker reference step {step}') + log(f'[step {step + 1}] reference forward OK ({elapsed_ref:.1f}s)') + + # Step B: Attach ref_logps to datums + for datum, ref_out in zip(input_datums, ref_result.loss_fn_outputs): + ref_logprobs_np = np.array(ref_out['logprobs'].tolist(), dtype=np.float32) + datum.loss_fn_inputs['ref_logps'] = types.TensorData.from_numpy(ref_logprobs_np) + + # Step C: DPO forward_backward + log(f'[step {step + 1}] forward_backward (DPO)...') + t1 = time.time() + fwdbwd_result = training_client.forward_backward( + input_datums, + 'importance_sampling', + loss_fn_config={ + 'dpo_beta': DPO_BETA, + 'dpo_sft_weight': DPO_SFT_WEIGHT, + }, + ).result() + elapsed_fb = time.time() - t1 + assert_no_timeout(elapsed_fb, f'dpo_tinker forward_backward step {step}') + log(f'[step {step + 1}] forward_backward OK ({elapsed_fb:.1f}s)') + + # Step D: Optimizer step + optim_result = training_client.optim_step( + types.AdamParams(learning_rate=1e-4) + ).result() + + if optim_result.metrics: + assert_metrics_valid(optim_result.metrics, f'dpo_tinker step {step}') + log(f'[step {step + 1}] metrics={optim_result.metrics}') + # Track loss and reward margins + loss_val = optim_result.metrics.get('loss') + if loss_val is not None: + losses.append(float(loss_val)) + reward_margin = optim_result.metrics.get('rewards/margins') + if reward_margin is not None: + reward_margins.append(float(reward_margin)) + + step += 1 + + assert step == DPO_TRAIN_STEPS, f'Expected {DPO_TRAIN_STEPS} steps, completed {step}' + + # Verify DPO loss decreases + if len(losses) >= 3 and not all(l == 0.0 for l in losses): + log(f'DPO losses: {["{:.4f}".format(l) for l in losses]}') + assert losses[-1] < losses[0], ( + f'[dpo_tinker] DPO loss did NOT decrease: first={losses[0]:.4f} last={losses[-1]:.4f}') + log(f'[dpo_tinker] Loss decreased: {losses[0]:.4f} -> {losses[-1]:.4f}') + elif backend == 'megatron': + log('[dpo_tinker] Megatron: loss reports 0 (known behavior), verifying training completed OK') + + # Verify reward margins increase + if len(reward_margins) >= 3 and not all(abs(r) < 1e-6 for r in reward_margins): + log(f'Reward margins: {["{:.4f}".format(r) for r in reward_margins]}') + assert reward_margins[-1] > reward_margins[0], ( + f'[dpo_tinker] Reward margins did NOT increase: first={reward_margins[0]:.4f} last={reward_margins[-1]:.4f}') + log(f'[dpo_tinker] Reward margins increased: {reward_margins[0]:.4f} -> {reward_margins[-1]:.4f}') + + log(f'test_dpo_tinker PASSED (backend={backend})') + + +# ── Direct execution ── + +def main() -> int: + log('Running DPO E2E tests directly...') + try: + test_dpo_twinkle() + test_dpo_tinker() + log('ALL DPO TESTS PASSED') + return 0 + except Exception as e: + log(f'FAILED: {e}') + import traceback + traceback.print_exc() + return 1 + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/tests/server/integration/test_dpo_nccl_safe_megatron.py b/tests/server/integration/test_dpo_nccl_safe_megatron.py deleted file mode 100644 index 96831bf78..000000000 --- a/tests/server/integration/test_dpo_nccl_safe_megatron.py +++ /dev/null @@ -1,184 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -"""DPO NCCL-safe verification test for Megatron backend. - -Tests that DPO training (forward_only + forward_backward) does NOT cause -NCCL hang when errors occur, verifying the nccl_safe_megatron fix. - -Prerequisites: - 1. Ray cluster running with GPUs - 2. Twinkle server started with Megatron backend and TWINKLE_FAIL_FAST=0 - -Usage (direct): - python tests/server/integration/test_dpo_nccl_safe_megatron.py -""" -from __future__ import annotations - -import os -import sys -import time - -import numpy as np - -SERVER_URL = os.environ.get('TWINKLE_SERVER_URL', 'http://localhost:9000') -BASE_MODEL = 'Qwen/Qwen3.5-4B' -TIMEOUT = 120 - - -def log(msg): - print(f'[DPO-Megatron] {msg}', flush=True) - - -def wait_for_server(url, timeout=300): - import requests - start = time.time() - while time.time() - start < timeout: - try: - resp = requests.get(f'{url}/-/routes', timeout=5) - if resp.status_code == 200: - log(f'Server ready ({int(time.time() - start)}s)') - return True - except Exception: - pass - time.sleep(5) - raise TimeoutError(f'Server not ready after {timeout}s') - - -def init_dpo_client(): - """Initialize Twinkle client for DPO training.""" - from twinkle_client import init_twinkle_client - from twinkle_client.model import MultiLoraTransformersModel - from peft import LoraConfig - - init_twinkle_client(base_url=SERVER_URL, api_key='EMPTY_TOKEN') - - model = MultiLoraTransformersModel(model_id=f'ms://{BASE_MODEL}') - model.add_adapter_to_model( - adapter_name='dpo-test', - config=LoraConfig(r=16, target_modules=['q_proj', 'v_proj']), - gradient_accumulation_steps=1, - ) - model.set_loss('DPOLoss', init_args={'beta': 0.1}) - model.set_optimizer('Adam', lr=1e-5) - model.set_template('Qwen3_5Template') - model.set_processor('InputProcessor', padding_side='right') - log('DPO client configured') - return model - - -def make_dpo_batch(batch_size=4, seq_len=64, completion_len=32): - """Create DPO batch: interleaved chosen/rejected pairs.""" - prompt_len = seq_len - completion_len - all_inputs = [] - - # DPO requires even batch (chosen/rejected pairs) - for i in range(batch_size): - input_ids = list(range(1, seq_len + 1)) - labels = [-100] * prompt_len + list(range(100, 100 + completion_len)) - all_inputs.append({ - 'input_ids': input_ids, - 'labels': labels, - 'attention_mask': [1] * seq_len, - }) - - return all_inputs - - -def run_dpo_step(model, test_name, *, bad_ref_logps=False): - """Execute a full DPO step: forward_only (ref) + forward_backward (policy).""" - batch = make_dpo_batch(batch_size=4) - - # Step 1: forward_only for reference logps - log(f'[{test_name}] forward_only (reference)...') - start = time.time() - try: - ref_result = model.forward_only(inputs=batch, disable_lora=True) - elapsed = time.time() - start - log(f'[{test_name}] forward_only OK ({elapsed:.1f}s)') - except Exception as e: - elapsed = time.time() - start - log(f'[{test_name}] forward_only FAILED ({elapsed:.1f}s): {e}') - if elapsed > TIMEOUT: - log(f'[{test_name}] TIMEOUT! NCCL HANG detected!') - return False - return True # Error was caught, not a hang - - # Step 2: forward_backward for policy training - log(f'[{test_name}] forward_backward (policy)...') - start = time.time() - try: - # For DPO, we need ref_logps from the reference forward - kwargs = {} - if hasattr(ref_result, 'result') and ref_result.result: - # Extract ref_logps from forward_only result - pass # In real DPO, client passes ref_logps - result = model.forward_backward(inputs=batch, **kwargs) - elapsed = time.time() - start - log(f'[{test_name}] forward_backward OK ({elapsed:.1f}s)') - except Exception as e: - elapsed = time.time() - start - log(f'[{test_name}] forward_backward FAILED ({elapsed:.1f}s): {e}') - if elapsed > TIMEOUT: - log(f'[{test_name}] TIMEOUT! NCCL HANG detected!') - return False - return True # Error was caught, not a hang - - # Step 3: optimizer step - try: - model.clip_grad_and_step() - log(f'[{test_name}] clip_grad_and_step OK') - except Exception as e: - log(f'[{test_name}] clip_grad_and_step FAILED: {e}') - - return True - - -def main(): - log('=' * 60) - log('DPO NCCL-Safe Verification - Megatron Backend') - log('=' * 60) - log(f'Server URL: {SERVER_URL}') - log(f'TWINKLE_FAIL_FAST = {os.getenv("TWINKLE_FAIL_FAST", "1 (default)")}') - - wait_for_server(SERVER_URL) - model = init_dpo_client() - - results = [] - - # Test 1: Normal DPO training (should work) - passed = run_dpo_step(model, 'TEST-1-NORMAL-DPO') - results.append(('TEST-1: Normal DPO', passed)) - - # Test 2: Multiple consecutive DPO steps - for i in range(3): - passed = run_dpo_step(model, f'TEST-2-CONSECUTIVE-{i+1}') - results.append((f'TEST-2-{i+1}: Consecutive DPO', passed)) - - # Test 3: forward_only then forward_backward rapidly - passed = run_dpo_step(model, 'TEST-3-RAPID') - results.append(('TEST-3: Rapid DPO', passed)) - - # Test 4: Health check after all DPO operations - log('[TEST-4] Final health check - forward_backward...') - batch = make_dpo_batch(batch_size=4) - start = time.time() - try: - model.forward_backward(inputs=batch) - elapsed = time.time() - start - log(f'[TEST-4] OK ({elapsed:.1f}s)') - results.append(('TEST-4: Final health', True)) - except Exception as e: - elapsed = time.time() - start - log(f'[TEST-4] FAILED ({elapsed:.1f}s): {e}') - results.append(('TEST-4: Final health', elapsed < TIMEOUT)) - - # Summary - log(f'\n{"=" * 60}\nRESULTS SUMMARY\n{"=" * 60}') - all_passed = all(p for _, p in results) - for name, status in results: - log(f' [{"PASS" if status else "FAIL"}] {name}') - log(f'\n{"ALL" if all_passed else "SOME"} {len(results)} TESTS {"PASSED" if all_passed else "FAILED"}!') - return 0 if all_passed else 1 - - -if __name__ == '__main__': - sys.exit(main()) diff --git a/tests/server/integration/test_dpo_pp_e2e.py b/tests/server/integration/test_dpo_pp_e2e.py deleted file mode 100644 index 425410e5e..000000000 --- a/tests/server/integration/test_dpo_pp_e2e.py +++ /dev/null @@ -1,191 +0,0 @@ -"""DPO training E2E test on Megatron PP=2 backend. - -Reproduces the PP deadlock where forward_only succeeds but forward_backward -hangs due to nccl_safe loss skip breaking pipeline P2P communication. - -Flow (mirrors cookbook/client/twinkle/modelscope/dpo.py): - Phase 1 — Setup: configure model with DPO loss - Phase 2 — forward_only (ref_outputs): base model inference, no LoRA - Phase 3 — forward_backward (DPO training): triggers PP P2P communication - -No sampler required. Server config: server_config_4b_dpo_megatron.yaml - -## How to run - - # 1. Start server (no sampler, 4 GPU model only) - python tests/server/start_e2e_server.py \\ - --config tests/server/config/server_config_4b_dpo_megatron.yaml - - # 2. Run DPO PP test - TWINKLE_TEST_GPU_E2E=1 python -u tests/server/integration/test_dpo_pp_e2e.py - -Expected: forward_only succeeds, forward_backward either succeeds or -reproduces the PP deadlock (504 timeout / NCCL hang). -""" -from __future__ import annotations - -import dotenv - -dotenv.load_dotenv('.env') - -import os # noqa: E402 -import sys # noqa: E402 -import time # noqa: E402 -from typing import Any, Dict, List # noqa: E402 - -import numpy as np # noqa: E402 -import pytest # noqa: E402 -import torch # noqa: E402 -from peft import LoraConfig # noqa: E402 - -pytestmark = pytest.mark.skipif( - os.environ.get('TWINKLE_TEST_GPU_E2E', '0') != '1', - reason='Set TWINKLE_TEST_GPU_E2E=1 to run real GPU E2E tests (requires running server)', -) - -from twinkle import get_logger, init_twinkle_client # noqa: E402 -from twinkle.dataloader import DataLoader # noqa: E402 -from twinkle.dataset import Dataset, DatasetMeta # noqa: E402 -from twinkle.preprocessor import EmojiDPOProcessor # noqa: E402 -from twinkle_client.model import MultiLoraTransformersModel # noqa: E402 - -logger = get_logger() - -# ── Configuration ── -BASE_MODEL = 'Qwen/Qwen3.5-4B' -BASE_URL = 'http://localhost:9000' -API_KEY = 'EMPTY_API_KEY' -DATASET_ID = 'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji' - -BATCH_SIZE = 4 -GRADIENT_ACCUMULATION_STEPS = 2 -DPO_BETA = 0.1 -SFT_WEIGHT = 1.0 -LOSS_TYPE = 'sigmoid' -MAX_LENGTH = 2048 -SYSTEM_PROMPT = 'You are a helpful assistant.' -DPO_TRAIN_STEPS = 4 # small number, just enough to trigger the bug -FORWARD_BACKWARD_TIMEOUT = 120 # seconds to wait before declaring hang - - -def _create_dpo_dataset() -> Dataset: - """Create DPO dataset with positive/negative format (small slice for speed).""" - dataset = Dataset(DatasetMeta(DATASET_ID, data_slice=range(100))) - dataset.set_template('Qwen3_5Template', model_id=f'ms://{BASE_MODEL}', max_length=MAX_LENGTH) - dataset.map(EmojiDPOProcessor, init_args={'system': SYSTEM_PROMPT}) - dataset.encode() - return dataset - - -def _prepare_dpo_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """Interleave positive/negative pairs: [pos_1, neg_1, pos_2, neg_2, ...]. - - This DP-safe interleaving ensures each DP worker gets complete pairs - after slicing. - """ - result = [] - for row in batch: - base_fields = {k: v for k, v in row.items() if k not in ('positive', 'negative')} - pos_sample = {**base_fields, **row['positive']} - neg_sample = {**base_fields, **row['negative']} - result.append(pos_sample) - result.append(neg_sample) - return result - - -def _convert_tensors(batch: List[Dict[str, Any]]) -> None: - """Convert numpy/torch tensors to lists for serialization (in-place).""" - for row in batch: - for key in row: - if isinstance(row[key], np.ndarray): - row[key] = row[key].tolist() - elif isinstance(row[key], torch.Tensor): - row[key] = row[key].cpu().numpy().tolist() - - -def _configure_dpo_model() -> MultiLoraTransformersModel: - """Configure model with DPO loss, optimizer, and LoRA adapter.""" - model = MultiLoraTransformersModel(model_id=f'ms://{BASE_MODEL}') - model.add_adapter_to_model( - 'default', - LoraConfig(target_modules='all-linear', r=8, lora_alpha=32, lora_dropout=0.05), - gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS, - ) - model.set_template('Qwen3_5Template') - model.set_processor('InputProcessor', padding_side='right') - model.set_loss('DPOLoss', beta=DPO_BETA, loss_type=LOSS_TYPE, reference_free=False, sft_weight=SFT_WEIGHT) - model.add_metric('DPOMetric', beta=DPO_BETA) - model.set_optimizer('Adam', lr=1e-4) - return model - - -def main() -> int: - client = init_twinkle_client(base_url=BASE_URL, api_key=API_KEY) - logger.info('Available models:') - for m in client.get_server_capabilities().supported_models: - logger.info(f' - {m.model_name}') - - # ── Phase 1: Setup ── - logger.info('=' * 60) - logger.info('Phase 1: Setup — configure DPO model + load dataset') - logger.info('=' * 60) - - dataset = _create_dpo_dataset() - dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE) - model = _configure_dpo_model() - - logger.info('Model and dataset ready. Starting DPO training loop...') - logger.info(f'DPO config: beta={DPO_BETA}, loss_type={LOSS_TYPE}, sft_weight={SFT_WEIGHT}') - logger.info(f'Training {DPO_TRAIN_STEPS} steps with GA={GRADIENT_ACCUMULATION_STEPS}') - - # ── Phase 2+3: DPO training loop ── - logger.info('=' * 60) - logger.info('Phase 2+3: DPO training (forward_only + forward_backward)') - logger.info('=' * 60) - - step = 0 - for batch in dataloader: - _convert_tensors(batch) - dpo_batch = _prepare_dpo_batch(batch) - - # Phase 2: forward_only — get reference outputs (base model, no LoRA) - logger.info(f'[Step {step + 1}] forward_only (ref_outputs) ...') - t0 = time.time() - ref_outputs = model.forward_only(inputs=dpo_batch, disable_lora=True) - t_fo = time.time() - t0 - logger.info(f'[Step {step + 1}] forward_only OK ({t_fo:.1f}s)') - - # Phase 3: forward_backward — DPO training with ref_outputs - logger.info(f'[Step {step + 1}] forward_backward (DPO loss) ...') - t0 = time.time() - model.forward_backward(inputs=dpo_batch, ref_outputs=ref_outputs.result) - t_fb = time.time() - t0 - logger.info(f'[Step {step + 1}] forward_backward OK ({t_fb:.1f}s)') - - model.clip_grad_and_step() - - # Log metrics every GA steps - step += 1 - if step % GRADIENT_ACCUMULATION_STEPS == 0: - metrics = model.calculate_metric(is_training=True) - logger.info(f'[Optim step {step // GRADIENT_ACCUMULATION_STEPS}] {metrics}') - - if step >= DPO_TRAIN_STEPS: - break - - logger.info('=' * 60) - logger.info('ALL DPO PHASES PASSED') - logger.info('=' * 60) - return 0 - - -# ── pytest entry point ── - -def test_dpo_pp_e2e(): - """Pytest-collected entry point for the DPO PP E2E suite.""" - rc = main() - assert rc == 0, 'DPO PP E2E test failed' - - -if __name__ == '__main__': - sys.exit(main()) diff --git a/tests/server/integration/test_dpo_tinker_pp_e2e.py b/tests/server/integration/test_dpo_tinker_pp_e2e.py deleted file mode 100644 index ca872be56..000000000 --- a/tests/server/integration/test_dpo_tinker_pp_e2e.py +++ /dev/null @@ -1,149 +0,0 @@ -"""Tinker client DPO test on Megatron PP=2 backend. - -Reproduces the tinker_forward_only ragged logps error. - -Usage: - # 1. Start server (PP=2, no sampler) - python tests/server/start_e2e_server.py \ - --config tests/server/config/server_config_4b_dpo_megatron.yaml - - # 2. Run test - TWINKLE_TEST_GPU_E2E=1 python -u tests/server/integration/test_dpo_tinker_pp_e2e.py -""" -from __future__ import annotations - -import os -import sys -import time - -import numpy as np -import torch - -SERVER_URL = os.environ.get('TWINKLE_SERVER_URL', 'http://localhost:9000') -BASE_MODEL = 'Qwen/Qwen3.5-4B' - - -def log(msg): - ts = time.strftime('%Y-%m-%d %H:%M:%S') - print(f'[{ts}][INFO:twinkle] {msg}', flush=True) - - -def main(): - if not os.environ.get('TWINKLE_TEST_GPU_E2E'): - log('SKIP: set TWINKLE_TEST_GPU_E2E=1 to run') - return 0 - - # ── Init tinker client ──────────────────────────────────────────── - from tinker import types - from twinkle import init_tinker_client - init_tinker_client() - from tinker import ServiceClient - - service_client = ServiceClient(base_url=SERVER_URL, api_key='EMPTY_TOKEN') - training_client = service_client.create_lora_training_client( - base_model=BASE_MODEL, rank=8, - ) - log(f'Tinker training client ready (model={BASE_MODEL})') - - # ── Prepare DPO dataset ─────────────────────────────────────────── - from twinkle.dataset import Dataset, DatasetMeta - from twinkle.preprocessor import EmojiDPOProcessor - from twinkle.server.common import input_feature_to_datum - - log('Loading DPO dataset (10 samples)...') - dataset = Dataset(DatasetMeta(f'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji', data_slice=range(10))) - dataset.set_template('Qwen3_5Template', model_id=f'ms://{BASE_MODEL}', max_length=1024) - dataset.map(EmojiDPOProcessor, init_args={'system': 'You are a helpful assistant.'}) - dataset.encode() - - # Build interleaved [pos, neg, pos, neg] batch - batch = list(dataset)[:4] - dpo_batch = [] - for row in batch: - for key in list(row.keys()): - if isinstance(row[key], np.ndarray): - row[key] = row[key].tolist() - elif isinstance(row[key], torch.Tensor): - row[key] = row[key].cpu().numpy().tolist() - base_fields = {k: v for k, v in row.items() if k not in ('positive', 'negative')} - dpo_batch.append({**base_fields, **row['positive']}) - dpo_batch.append({**base_fields, **row['negative']}) - log(f'DPO batch: {len(dpo_batch)} samples (interleaved pos/neg)') - - # Convert to Tinker Datums - input_datums = [input_feature_to_datum(row) for row in dpo_batch] - seq_lens = [d.loss_fn_inputs['target_tokens'].to_numpy().shape[0] for d in input_datums] - log(f'Datum seq_lens: {seq_lens}') - - # ── Step 1: Reference forward (tinker_forward_only) ─────────────── - log('=' * 60) - log('Step 1: tinker forward (reference, disable_lora=True)...') - log('=' * 60) - start = time.time() - try: - ref_result = training_client.forward( - input_datums, 'cross_entropy', - loss_fn_config={'disable_lora': True}, - ).result() - elapsed = time.time() - start - log(f'Step 1 OK ({elapsed:.1f}s), {len(ref_result.loss_fn_outputs)} outputs') - - # Show logprobs shapes - for i, out in enumerate(ref_result.loss_fn_outputs): - lp = out.get('logprobs') - if lp is not None: - arr = np.array(lp.tolist()) - log(f' output[{i}] logprobs shape={arr.shape}') - except Exception as e: - elapsed = time.time() - start - log(f'Step 1 FAILED ({elapsed:.1f}s): {type(e).__name__}: {e}') - if elapsed > 120: - log('TIMEOUT — likely NCCL hang!') - log('Check server log for traceback') - return 1 - - # ── Step 2: Attach ref_logps to datums ──────────────────────────── - log('Step 2: Attaching ref_logps to datums...') - for datum, ref_out in zip(input_datums, ref_result.loss_fn_outputs): - ref_logprobs_np = np.array(ref_out['logprobs'].tolist(), dtype=np.float32) - datum.loss_fn_inputs['ref_logps'] = types.TensorData.from_numpy(ref_logprobs_np) - log(f' ref_logps shape={ref_logprobs_np.shape}') - - # ── Step 3: DPO forward_backward (tinker_forward_backward) ──────── - log('=' * 60) - log('Step 3: tinker forward_backward (DPO loss)...') - log('=' * 60) - start = time.time() - try: - fwdbwd_result = training_client.forward_backward( - input_datums, 'importance_sampling', - loss_fn_config={'dpo_beta': 0.1, 'dpo_sft_weight': 1.0}, - ).result() - elapsed = time.time() - start - log(f'Step 3 OK ({elapsed:.1f}s)') - except Exception as e: - elapsed = time.time() - start - log(f'Step 3 FAILED ({elapsed:.1f}s): {type(e).__name__}: {e}') - if elapsed > 120: - log('TIMEOUT — likely NCCL hang!') - return 1 - - # ── Step 4: Optimizer step ──────────────────────────────────────── - log('Step 4: optim_step...') - try: - optim_result = training_client.optim_step( - types.AdamParams(learning_rate=1e-4) - ).result() - log(f'Step 4 OK, metrics={optim_result.metrics}') - except Exception as e: - log(f'Step 4 FAILED: {e}') - return 1 - - log('=' * 60) - log('ALL TINKER DPO PHASES PASSED') - log('=' * 60) - return 0 - - -if __name__ == '__main__': - sys.exit(main()) diff --git a/tests/server/integration/test_grpo_e2e.py b/tests/server/integration/test_grpo_e2e.py new file mode 100644 index 000000000..909c998b7 --- /dev/null +++ b/tests/server/integration/test_grpo_e2e.py @@ -0,0 +1,428 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""GRPO (Group Relative Policy Optimization) E2E integration tests. + +Tests GRPO training across all 4 combinations: + - Twinkle client x (transformers | megatron) + - Tinker client x (transformers | megatron) + +Backend selection via env var TWINKLE_TEST_BACKEND (default: transformers). +Requires sampler service to be running. + +## How to run + + # Start server (must include sampler) + python tests/server/start_e2e_server.py --config tests/server/config/server_config_4b_e2e.yaml + + # Run GRPO tests + TWINKLE_TEST_GPU_E2E=1 TWINKLE_TEST_BACKEND=transformers pytest tests/server/integration/test_grpo_e2e.py -v +""" +from __future__ import annotations + +import os +import re +import sys +import time +from typing import Any, Dict, List, Tuple + +# Ensure project root is in sys.path for both pytest and direct execution +_PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..')) +if _PROJECT_ROOT not in sys.path: + sys.path.insert(0, _PROJECT_ROOT) + +import numpy as np +import pytest + +pytestmark = pytest.mark.skipif( + os.environ.get('TWINKLE_TEST_GPU_E2E', '0') != '1', + reason='Set TWINKLE_TEST_GPU_E2E=1 to run real GPU E2E tests (requires running server)', +) + +from tests.server.integration.e2e_helpers import ( + BASE_MODEL, + BASE_URL, + GRADIENT_ACCUMULATION_STEPS, + MODEL_ID, + TIMEOUT, + assert_metrics_valid, + assert_no_timeout, + create_grpo_dataset, + create_tinker_training_client, + create_twinkle_grpo_model, + create_twinkle_sampler, + get_backend, + init_twinkle_client_session, + log, + wait_for_server, +) + +# ── Configuration ── +GRPO_TRAIN_STEPS = 2 +NUM_GENERATIONS = 4 +MAX_NEW_TOKENS = 512 +LEARNING_RATE = 2e-5 +TEMPERATURE = 1.0 + +SYSTEM_PROMPT = ('You are a helpful math assistant. Solve the problem with minimal but correct reasoning ' + 'and put your final answer within \\boxed{}.') + + +# ═══════════════════════════════════════════════════════════════════════════ +# Reward Functions (lightweight versions for E2E testing) +# ═══════════════════════════════════════════════════════════════════════════ + +def compute_rewards(trajectories: List[Dict[str, Any]]) -> Tuple[List[float], List[float]]: + """Compute accuracy and brevity rewards for GSM8K. + + Returns (total_rewards, accuracy_rewards). + """ + from twinkle.reward import GSM8KAccuracyReward + + accuracy_reward_fn = GSM8KAccuracyReward() + accuracy_rewards = accuracy_reward_fn(trajectories) + + # Simple brevity reward + brevity_rewards = [] + for traj in trajectories: + messages = traj.get('messages', []) + completion = '' + for msg in reversed(messages): + if msg.get('role') == 'assistant': + completion = msg.get('content', '') + break + has_answer = bool( + re.search(r'\\boxed\{[^}]+\}', completion) + or re.search(r'####\s*[\-\d,\.]+', completion) + ) + if not has_answer: + brevity_rewards.append(0.0) + else: + length = len(completion) + brevity_rewards.append(max(0.0, 1.0 - max(0, length - 200) / 3000)) + + total_rewards = [a + b for a, b in zip(accuracy_rewards, brevity_rewards)] + return total_rewards, accuracy_rewards + + +# ═══════════════════════════════════════════════════════════════════════════ +# Test: GRPO via Twinkle client +# ═══════════════════════════════════════════════════════════════════════════ + +def test_grpo_twinkle(): + """GRPO training via Twinkle client (model.save + sampler + forward_backward). + + Flow per step: + 1. model.save(is_sampler=True) -> adapter_uri + 2. sampler.sample(inputs, adapter_uri, num_samples=N) + 3. Compute rewards + advantages + 4. model.forward_backward(inputs, advantages, old_logps) + 5. model.clip_grad_and_step() + + Pass criteria: + - Sampling returns non-empty completions + - forward_backward completes without timeout + - Metrics are valid (non-NaN/Inf) + - No NCCL hang + """ + from twinkle.advantage import GRPOAdvantage + from twinkle.dataloader import DataLoader + + backend = get_backend() + log(f'=== test_grpo_twinkle [backend={backend}] ===') + + wait_for_server() + init_twinkle_client_session() + + # Setup + dataset = create_grpo_dataset() + dataloader = DataLoader(dataset=dataset, batch_size=2, num_workers=0) + model = create_twinkle_grpo_model() + sampler = create_twinkle_sampler() + advantage_fn = GRPOAdvantage() + + sampling_params = { + 'max_tokens': MAX_NEW_TOKENS, + 'temperature': TEMPERATURE, + 'top_p': 0.95, + 'num_samples': NUM_GENERATIONS, + 'logprobs': 1, + } + + log(f'Dataset: {len(dataset)} samples, GRPO training {GRPO_TRAIN_STEPS} steps') + log(f'NUM_GENERATIONS={NUM_GENERATIONS}, MAX_NEW_TOKENS={MAX_NEW_TOKENS}') + + # Training loop + current_adapter_uri = None + step = 0 + + for batch in dataloader: + if step >= GRPO_TRAIN_STEPS: + break + + prompts = batch if isinstance(batch, list) else [batch] + + # Step 1: Save weights for sampler + log(f'[step {step + 1}] Saving weights for sampler...') + t0 = time.time() + save_result = model.save(name='grpo-e2e-weights', save_optimizer=False, is_sampler=True) + current_adapter_uri = save_result.twinkle_path + elapsed_save = time.time() - t0 + log(f'[step {step + 1}] Weights saved ({elapsed_save:.1f}s): {current_adapter_uri}') + + # Step 2: Sample completions + log(f'[step {step + 1}] Sampling {len(prompts)} prompts x {NUM_GENERATIONS} generations...') + t1 = time.time() + sample_responses = sampler.sample( + inputs=prompts, + sampling_params=sampling_params, + adapter_uri=current_adapter_uri, + ) + elapsed_sample = time.time() - t1 + assert_no_timeout(elapsed_sample, f'grpo_twinkle sampling step {step}') + + # Collect sequences + all_input_data: List[Dict[str, Any]] = [] + all_old_logps: List[List[float]] = [] + all_completion_lengths: List[int] = [] + + for sample_response in sample_responses: + for sequence in sample_response.sequences: + all_input_data.append(sequence.new_input_feature) + all_old_logps.append([logprob[0][1] for logprob in sequence.logprobs]) + all_completion_lengths.append(len(sequence.tokens)) + + assert len(all_input_data) > 0, f'[step {step + 1}] Sampling returned no completions!' + log(f'[step {step + 1}] Got {len(all_input_data)} completions ({elapsed_sample:.1f}s)') + + # Step 3: Compute rewards + advantages + total_rewards, accuracy_rewards = compute_rewards(all_input_data) + advantages = advantage_fn( + total_rewards, + num_generations=NUM_GENERATIONS, + scale='group', + ).tolist() + + # Skip if all advantages are zero (no learning signal) + if all(abs(a) < 1e-8 for a in advantages): + log(f'[step {step + 1}] All advantages zero, skipping (still counts as success)') + step += 1 + continue + + # Step 4: forward_backward with GRPO loss + log(f'[step {step + 1}] forward_backward (GRPO)...') + t2 = time.time() + model.forward_backward( + inputs=all_input_data, + advantages=advantages, + old_logps=all_old_logps, + ) + elapsed_fb = time.time() - t2 + assert_no_timeout(elapsed_fb, f'grpo_twinkle forward_backward step {step}') + log(f'[step {step + 1}] forward_backward OK ({elapsed_fb:.1f}s)') + + # Step 5: Optimizer step + model.clip_grad_and_step() + + # Log metrics + metrics = model.calculate_metric(is_training=True) + if hasattr(metrics, 'result'): + assert_metrics_valid(metrics.result, f'grpo_twinkle step {step}') + + step += 1 + log(f'[step {step}] Complete. accuracy_rewards={accuracy_rewards[:4]}') + + assert step >= GRPO_TRAIN_STEPS, f'Expected {GRPO_TRAIN_STEPS} steps, completed {step}' + log(f'test_grpo_twinkle PASSED (backend={backend})') + + +# ═══════════════════════════════════════════════════════════════════════════ +# Test: GRPO via Tinker client +# ═══════════════════════════════════════════════════════════════════════════ + +def test_grpo_tinker(): + """GRPO training via Tinker client (save_weights_and_get_sampling_client). + + Flow per step: + 1. save_weights_and_get_sampling_client() -> sampling_client + 2. sampling_client.sample(prompt, params, num_samples=N) + 3. Compute rewards + advantages + 4. Build Datums with logprobs + advantages + 5. forward_backward (importance_sampling) + 6. optim_step + + Pass criteria: + - Sampling returns non-empty completions + - forward_backward completes without timeout + - Metrics are valid + - No NCCL hang + """ + from tinker import types + from twinkle.advantage import GRPOAdvantage + from twinkle.dataloader import DataLoader + from twinkle.template import Qwen3_5Template + + backend = get_backend() + log(f'=== test_grpo_tinker [backend={backend}] ===') + + wait_for_server() + + # Setup + dataset = create_grpo_dataset() + dataloader = DataLoader(dataset=dataset, batch_size=2, num_workers=0) + training_client = create_tinker_training_client(rank=8) + template = Qwen3_5Template(model_id=MODEL_ID) + advantage_fn = GRPOAdvantage() + + sampling_params = types.SamplingParams( + max_tokens=MAX_NEW_TOKENS, + temperature=TEMPERATURE, + top_p=0.95, + ) + + log(f'Dataset: {len(dataset)} samples, GRPO training {GRPO_TRAIN_STEPS} steps') + + # Training loop + sampling_client = None + step = 0 + + for batch in dataloader: + if step >= GRPO_TRAIN_STEPS: + break + + prompts = batch if isinstance(batch, list) else [batch] + + # Step 1: Save weights and get sampling client + log(f'[step {step + 1}] Saving weights for sampler...') + t0 = time.time() + sampling_client = training_client.save_weights_and_get_sampling_client() + elapsed_save = time.time() - t0 + log(f'[step {step + 1}] Sampling client ready ({elapsed_save:.1f}s)') + + # Step 2: Sample completions + log(f'[step {step + 1}] Sampling...') + t1 = time.time() + all_sequences = [] + all_user_data = [] + + for prompt_feature in prompts: + input_ids = prompt_feature['input_ids'] + if hasattr(input_ids, 'tolist'): + input_ids = input_ids.tolist() + prompt = types.ModelInput.from_ints(input_ids) + future = sampling_client.sample( + prompt=prompt, + sampling_params=sampling_params, + num_samples=NUM_GENERATIONS, + ) + result = future.result() + for _ in range(NUM_GENERATIONS): + all_user_data.append(prompt_feature.get('user_data', [])) + all_sequences.extend(result.sequences) + + elapsed_sample = time.time() - t1 + assert_no_timeout(elapsed_sample, f'grpo_tinker sampling step {step}') + assert len(all_sequences) > 0, f'[step {step + 1}] Sampling returned no sequences!' + log(f'[step {step + 1}] Got {len(all_sequences)} sequences ({elapsed_sample:.1f}s)') + + # Step 3: Build trajectories and compute rewards + trajectories = [] + completion_lengths = [] + + for idx, seq in enumerate(all_sequences): + decoded_text = template.decode(seq.tokens, skip_special_tokens=True) + trajectories.append({ + 'messages': [ + {'role': 'system', 'content': SYSTEM_PROMPT}, + {'role': 'user', 'content': 'Math problem'}, + {'role': 'assistant', 'content': decoded_text}, + ], + 'user_data': all_user_data[idx], + }) + completion_lengths.append(len(seq.tokens)) + + total_rewards, accuracy_rewards = compute_rewards(trajectories) + + # Step 4: Compute advantages + advantages = advantage_fn( + total_rewards, + num_generations=NUM_GENERATIONS, + scale='group', + ).tolist() + + if all(abs(a) < 1e-8 for a in advantages): + log(f'[step {step + 1}] All advantages zero, skipping') + step += 1 + continue + + # Step 5: Build training Datums + training_data = [] + for i, seq in enumerate(all_sequences): + prompt_feature = prompts[i // NUM_GENERATIONS] + prompt_ids = prompt_feature['input_ids'] + if hasattr(prompt_ids, 'tolist'): + prompt_ids = prompt_ids.tolist() + + sampled_tokens = list(seq.tokens) + logprobs = seq.logprobs if seq.logprobs else [0.0] * len(sampled_tokens) + advantage = float(advantages[i]) + + ob_len = len(prompt_ids) - 1 + input_tokens = prompt_ids + sampled_tokens[:-1] + target_tokens = [0] * ob_len + sampled_tokens + weights = [0] * ob_len + [1] * len(sampled_tokens) + padded_advantages = [0.0] * ob_len + [advantage] * len(sampled_tokens) + padded_logprobs = [0.0] * ob_len + list(logprobs) + + datum = types.Datum( + model_input=types.ModelInput.from_ints(input_tokens), + loss_fn_inputs={ + 'target_tokens': target_tokens, + 'weights': weights, + 'logprobs': types.TensorData.from_numpy(np.array(padded_logprobs, dtype=np.float32)), + 'advantages': types.TensorData.from_numpy(np.array(padded_advantages, dtype=np.float32)), + }, + ) + training_data.append(datum) + + if not training_data: + log(f'[step {step + 1}] No training data, skipping') + step += 1 + continue + + # Step 6: forward_backward with importance_sampling + log(f'[step {step + 1}] forward_backward ({len(training_data)} datums)...') + t2 = time.time() + fwdbwd_result = training_client.forward_backward(training_data, 'importance_sampling').result() + elapsed_fb = time.time() - t2 + assert_no_timeout(elapsed_fb, f'grpo_tinker forward_backward step {step}') + log(f'[step {step + 1}] forward_backward OK ({elapsed_fb:.1f}s)') + + # Step 7: Optimizer step + optim_result = training_client.optim_step(types.AdamParams(learning_rate=LEARNING_RATE)).result() + if optim_result.metrics: + assert_metrics_valid(optim_result.metrics, f'grpo_tinker step {step}') + + step += 1 + log(f'[step {step}] Complete. accuracy_rewards={accuracy_rewards[:4]}') + + assert step >= GRPO_TRAIN_STEPS, f'Expected {GRPO_TRAIN_STEPS} steps, completed {step}' + log(f'test_grpo_tinker PASSED (backend={backend})') + + +# ── Direct execution ── + +def main() -> int: + log('Running GRPO E2E tests directly...') + try: + test_grpo_twinkle() + test_grpo_tinker() + log('ALL GRPO TESTS PASSED') + return 0 + except Exception as e: + log(f'FAILED: {e}') + import traceback + traceback.print_exc() + return 1 + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/tests/server/integration/test_sft_e2e.py b/tests/server/integration/test_sft_e2e.py new file mode 100644 index 000000000..4697337a1 --- /dev/null +++ b/tests/server/integration/test_sft_e2e.py @@ -0,0 +1,200 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""SFT (Supervised Fine-Tuning) E2E integration tests. + +Tests SFT training across all 4 combinations: + - Twinkle client x (transformers | megatron) + - Tinker client x (transformers | megatron) + +Backend selection via env var TWINKLE_TEST_BACKEND (default: transformers). + +## How to run + + # Start server (transformers or megatron) + python tests/server/start_e2e_server.py --config tests/server/config/server_config_4b_e2e.yaml + + # Run SFT tests + TWINKLE_TEST_GPU_E2E=1 TWINKLE_TEST_BACKEND=transformers pytest tests/server/integration/test_sft_e2e.py -v +""" +from __future__ import annotations + +import os +import sys +import time + +# Ensure project root is in sys.path for both pytest and direct execution +_PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..')) +if _PROJECT_ROOT not in sys.path: + sys.path.insert(0, _PROJECT_ROOT) + +import numpy as np +import pytest + +pytestmark = pytest.mark.skipif( + os.environ.get('TWINKLE_TEST_GPU_E2E', '0') != '1', + reason='Set TWINKLE_TEST_GPU_E2E=1 to run real GPU E2E tests (requires running server)', +) + +from tests.server.integration.e2e_helpers import ( + BASE_URL, + GRADIENT_ACCUMULATION_STEPS, + TIMEOUT, + assert_loss_decreases, + assert_no_timeout, + convert_tensors, + create_sft_dataset, + create_tinker_training_client, + create_twinkle_sft_model, + get_backend, + init_twinkle_client_session, + log, + wait_for_server, +) + +# ── Configuration ── +SFT_TRAIN_STEPS = 20 # 20 steps ensures enough training for both backends + + +# ═══════════════════════════════════════════════════════════════════════════ +# Test: SFT via Twinkle client +# ═══════════════════════════════════════════════════════════════════════════ + +def test_sft_twinkle(): + """SFT training via Twinkle client (MultiLoraTransformersModel). + + Pass criteria: + - Training completes 10 steps without timeout + - Loss shows downward trend (last_3_avg < first_3_avg) + """ + backend = get_backend() + log(f'=== test_sft_twinkle [backend={backend}] ===') + + wait_for_server() + init_twinkle_client_session() + + # Setup + from twinkle.dataloader import DataLoader + + dataset = create_sft_dataset() + dataloader = DataLoader(dataset=dataset, batch_size=4) + model = create_twinkle_sft_model() + + log(f'Dataset: {len(dataset)} samples, {len(dataloader)} batches') + log(f'Training {SFT_TRAIN_STEPS} steps (GA={GRADIENT_ACCUMULATION_STEPS})') + + # Training loop + losses = [] + for step, batch in enumerate(dataloader): + if step >= SFT_TRAIN_STEPS: + break + + t0 = time.time() + model.forward_backward(inputs=batch) + model.clip_grad_and_step() + elapsed = time.time() - t0 + assert_no_timeout(elapsed, f'sft_twinkle step {step}') + + # Log metric every GA steps + if (step + 1) % GRADIENT_ACCUMULATION_STEPS == 0: + metric = model.calculate_metric(is_training=True) + try: + loss = float(metric.result.get('loss')) if hasattr(metric.result, 'get') else float( + metric.result['loss']) + except Exception: + loss = float('nan') + losses.append(loss) + log(f'[step {step + 1}] loss={loss:.4f} ({elapsed:.1f}s)') + + # For Megatron backend, calculate_metric(is_training=True) has a known server-side bug + # that always returns loss=0 (test_full_cycle_e2e.py also reproduces this). + # Loss verification for Megatron is done via test_sft_tinker (logprobs-based). + if backend == 'megatron': + log('[sft_twinkle] Megatron: calculate_metric server bug (loss=0), training completed OK') + log('[sft_twinkle] Loss decrease verified via test_sft_tinker (logprobs-based)') + else: + assert len(losses) >= 4, f'Expected at least 4 logged losses, got {len(losses)}' + assert_loss_decreases(losses, 'sft_twinkle') + log(f'test_sft_twinkle PASSED (backend={backend})') + + +# ═══════════════════════════════════════════════════════════════════════════ +# Test: SFT via Tinker client +# ═══════════════════════════════════════════════════════════════════════════ + +def test_sft_tinker(): + """SFT training via Tinker client (ServiceClient + forward_backward). + + Pass criteria: + - Training completes 10 steps without timeout + - Loss shows downward trend (last_3_avg < first_3_avg) + """ + from tinker import types + from twinkle.dataloader import DataLoader + from twinkle.server.common import input_feature_to_datum + + backend = get_backend() + log(f'=== test_sft_tinker [backend={backend}] ===') + + wait_for_server() + + # Setup + dataset = create_sft_dataset() + dataloader = DataLoader(dataset=dataset, batch_size=4) + training_client = create_tinker_training_client(rank=16) + + log(f'Dataset: {len(dataset)} samples, {len(dataloader)} batches') + log(f'Training {SFT_TRAIN_STEPS} steps') + + # Training loop + losses = [] + for step, batch in enumerate(dataloader): + if step >= SFT_TRAIN_STEPS: + break + + # Convert batch to Tinker Datums + input_datums = [input_feature_to_datum(input_feature) for input_feature in batch] + + # Forward-backward + t0 = time.time() + fwdbwd_result = training_client.forward_backward(input_datums, 'cross_entropy').result() + elapsed_fb = time.time() - t0 + assert_no_timeout(elapsed_fb, f'sft_tinker forward_backward step {step}') + + # Optimizer step + optim_result = training_client.optim_step(types.AdamParams(learning_rate=1e-4)).result() + elapsed_total = time.time() - t0 + assert_no_timeout(elapsed_total, f'sft_tinker total step {step}') + + # Compute loss from logprobs + try: + logprobs = np.concatenate([output['logprobs'].tolist() for output in fwdbwd_result.loss_fn_outputs]) + weights = np.concatenate([example.loss_fn_inputs['weights'].tolist() for example in input_datums]) + loss = float(-np.dot(logprobs, weights) / max(weights.sum(), 1e-8)) + except Exception: + loss = float('nan') + losses.append(loss) + log(f'[step {step + 1}] loss={loss:.4f} ({elapsed_total:.1f}s)') + + # Assertions + assert len(losses) >= 4, f'Expected at least 4 logged losses, got {len(losses)}' + assert_loss_decreases(losses, 'sft_tinker') + log(f'test_sft_tinker PASSED (backend={backend})') + + +# ── Direct execution ── + +def main() -> int: + log('Running SFT E2E tests directly...') + try: + test_sft_twinkle() + test_sft_tinker() + log('ALL SFT TESTS PASSED') + return 0 + except Exception as e: + log(f'FAILED: {e}') + import traceback + traceback.print_exc() + return 1 + + +if __name__ == '__main__': + sys.exit(main()) From 70b2556798dbbbfed114cbea267fe57a7736b190 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Fri, 26 Jun 2026 16:08:41 +0800 Subject: [PATCH 14/16] update reg test --- src/twinkle/model/megatron/megatron.py | 11 ++++++----- tests/server/integration/e2e_helpers.py | 2 +- tests/server/integration/test_dpo_e2e.py | 13 ++++--------- tests/server/integration/test_sft_e2e.py | 12 +++--------- 4 files changed, 14 insertions(+), 24 deletions(-) diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index 76e0c8dba..dc34a1358 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -386,11 +386,12 @@ def post_loss_function(output_tensor, inputs, logps, unpacked_logits=None, entro losses = result['loss'] counts = result['num_tokens'] if not counts: - # safe_loss returned zero loss (num_tokens=0): use output_tensor - # to rebuild graph connectivity so backward triggers ALL gradient - # buckets, preventing DP AllReduce asymmetry in PP mode. - if output_tensor.requires_grad: - losses = (output_tensor.flatten()[:1] * 0).sum() + # num_tokens=0 covers two cases: + # 1. Normal mean-reduction loss (e.g. CrossEntropyLoss reduction='mean') + # → losses already contains the correct mean loss, just set counts=1. + # 2. safe_loss error degradation → losses is a graph-connected zero. + # Both cases need counts=1 so reduce_loss divides correctly. + # Do NOT overwrite losses here; case 1 would lose the real loss value. counts = torch.tensor(1, device=losses.device) return self.strategy.reduce_loss(losses, counts, output_tensor, logps) diff --git a/tests/server/integration/e2e_helpers.py b/tests/server/integration/e2e_helpers.py index aaaa87495..9d3073bad 100644 --- a/tests/server/integration/e2e_helpers.py +++ b/tests/server/integration/e2e_helpers.py @@ -74,7 +74,7 @@ def create_sft_dataset(data_slice=range(100)): dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=data_slice)) dataset.set_template('Qwen3_5Template', model_id=MODEL_ID, max_length=256) dataset.map('SelfCognitionProcessor', init_args={'model_name': 'twinkle模型', 'model_author': 'ModelScope社区'}) - dataset.encode(batched=True, load_from_cache_file=False) + dataset.encode(batched=True) return dataset diff --git a/tests/server/integration/test_dpo_e2e.py b/tests/server/integration/test_dpo_e2e.py index d66250a52..035ae188f 100644 --- a/tests/server/integration/test_dpo_e2e.py +++ b/tests/server/integration/test_dpo_e2e.py @@ -146,17 +146,14 @@ def test_dpo_twinkle(): assert step == DPO_TRAIN_STEPS, f'Expected {DPO_TRAIN_STEPS} steps, completed {step}' # Verify DPO loss decreases - backend = get_backend() - if len(losses) >= 3 and not all(l == 0.0 for l in losses): + if len(losses) >= 3: log(f'DPO losses: {["{:.4f}".format(l) for l in losses]}') assert losses[-1] < losses[0], ( f'[dpo_twinkle] DPO loss did NOT decrease: first={losses[0]:.4f} last={losses[-1]:.4f}') log(f'[dpo_twinkle] Loss decreased: {losses[0]:.4f} -> {losses[-1]:.4f}') - elif backend == 'megatron': - log('[dpo_twinkle] Megatron: loss reports 0 (known behavior), verifying training completed OK') # Verify reward margins increase (DPO learns to prefer chosen) - if len(reward_margins) >= 3 and not all(abs(r) < 1e-6 for r in reward_margins): + if len(reward_margins) >= 3: log(f'Reward margins: {["{:.4f}".format(r) for r in reward_margins]}') assert reward_margins[-1] > reward_margins[0], ( f'[dpo_twinkle] Reward margins did NOT increase: first={reward_margins[0]:.4f} last={reward_margins[-1]:.4f}') @@ -269,16 +266,14 @@ def test_dpo_tinker(): assert step == DPO_TRAIN_STEPS, f'Expected {DPO_TRAIN_STEPS} steps, completed {step}' # Verify DPO loss decreases - if len(losses) >= 3 and not all(l == 0.0 for l in losses): + if len(losses) >= 3: log(f'DPO losses: {["{:.4f}".format(l) for l in losses]}') assert losses[-1] < losses[0], ( f'[dpo_tinker] DPO loss did NOT decrease: first={losses[0]:.4f} last={losses[-1]:.4f}') log(f'[dpo_tinker] Loss decreased: {losses[0]:.4f} -> {losses[-1]:.4f}') - elif backend == 'megatron': - log('[dpo_tinker] Megatron: loss reports 0 (known behavior), verifying training completed OK') # Verify reward margins increase - if len(reward_margins) >= 3 and not all(abs(r) < 1e-6 for r in reward_margins): + if len(reward_margins) >= 3: log(f'Reward margins: {["{:.4f}".format(r) for r in reward_margins]}') assert reward_margins[-1] > reward_margins[0], ( f'[dpo_tinker] Reward margins did NOT increase: first={reward_margins[0]:.4f} last={reward_margins[-1]:.4f}') diff --git a/tests/server/integration/test_sft_e2e.py b/tests/server/integration/test_sft_e2e.py index 4697337a1..40c794b11 100644 --- a/tests/server/integration/test_sft_e2e.py +++ b/tests/server/integration/test_sft_e2e.py @@ -104,15 +104,9 @@ def test_sft_twinkle(): losses.append(loss) log(f'[step {step + 1}] loss={loss:.4f} ({elapsed:.1f}s)') - # For Megatron backend, calculate_metric(is_training=True) has a known server-side bug - # that always returns loss=0 (test_full_cycle_e2e.py also reproduces this). - # Loss verification for Megatron is done via test_sft_tinker (logprobs-based). - if backend == 'megatron': - log('[sft_twinkle] Megatron: calculate_metric server bug (loss=0), training completed OK') - log('[sft_twinkle] Loss decrease verified via test_sft_tinker (logprobs-based)') - else: - assert len(losses) >= 4, f'Expected at least 4 logged losses, got {len(losses)}' - assert_loss_decreases(losses, 'sft_twinkle') + # Assertions — both backends should report real loss via calculate_metric + assert len(losses) >= 4, f'Expected at least 4 logged losses, got {len(losses)}' + assert_loss_decreases(losses, 'sft_twinkle') log(f'test_sft_twinkle PASSED (backend={backend})') From b42cf06c1174298dec55a4c503dfb0715d62cc35 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Fri, 26 Jun 2026 17:05:40 +0800 Subject: [PATCH 15/16] update reg test --- src/twinkle/server/model/backends/megatron_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/twinkle/server/model/backends/megatron_model.py b/src/twinkle/server/model/backends/megatron_model.py index fd42a29b3..21b10ee73 100644 --- a/src/twinkle/server/model/backends/megatron_model.py +++ b/src/twinkle/server/model/backends/megatron_model.py @@ -106,7 +106,7 @@ def forward_only(self, *, inputs: InputFeature | list[InputFeature] | Trajectory output = super().forward_only(inputs=inputs, **kwargs) return to_cpu_safe_output(output) - @remote_function(dispatch='slice_dp', collect=collect_tensor_dict) + @remote_function(dispatch='slice_dp', collect=collect_tensor_dict, sync=True) @nccl_safe_megatron def forward_backward(self, *, inputs: InputFeature | list[InputFeature] | Trajectory | list[Trajectory], **kwargs): """Forward+backward for twinkle-native clients (InputFeature/Trajectory I/O).""" From b80d14acf07e510cbf55347e0de48f3f50e83001 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Fri, 26 Jun 2026 17:46:57 +0800 Subject: [PATCH 16/16] revert comment --- src/twinkle/model/megatron/megatron.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index dc34a1358..7a94d4422 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -386,12 +386,14 @@ def post_loss_function(output_tensor, inputs, logps, unpacked_logits=None, entro losses = result['loss'] counts = result['num_tokens'] if not counts: - # num_tokens=0 covers two cases: - # 1. Normal mean-reduction loss (e.g. CrossEntropyLoss reduction='mean') - # → losses already contains the correct mean loss, just set counts=1. - # 2. safe_loss error degradation → losses is a graph-connected zero. - # Both cases need counts=1 so reduce_loss divides correctly. - # Do NOT overwrite losses here; case 1 would lose the real loss value. + # Later will gather this value, so it becomes: + # 1. SUM loss: gather_sum(local_num_tokens) = global_num_tokens + # 2. PER TOKEN MEAN loss: gather_sum(1 * gradient_accumulation_steps ) + # = gradient_accumulation_steps * world_size + # Then, grad will divided by this value: + # 1. SUM loss: (global_sum_grad) / (global_num_tokens) = global_sum_grad/global_num_tokens + # 2. PER TOKEN MEAN loss: (gather_sum(per_token_grad * gradient_accumulation_steps)) + # / (gradient_accumulation_steps * world_size ) = avg_per_token_grad counts = torch.tensor(1, device=losses.device) return self.strategy.reduce_loss(losses, counts, output_tensor, logps)