diff --git a/cookbook/client/server/megatron/entrypoint.sh b/cookbook/client/server/megatron/entrypoint.sh index 34a2bfa5e..879e3890c 100755 --- a/cookbook/client/server/megatron/entrypoint.sh +++ b/cookbook/client/server/megatron/entrypoint.sh @@ -11,10 +11,12 @@ TWINKLE_WORK_DIR="${TWINKLE_WORK_DIR:-/dashscope/caches/application/twinkle}" TEMP_DIR="${TWINKLE_TEMP_DIR:-/dashscope/caches/application/ray_logs}" LOG_FILE="$TWINKLE_WORK_DIR/run.log" TWINKLE_HEALTH_URL="${TWINKLE_HEALTH_URL:-http://127.0.0.1:9000/api/v1/healthz}" +TWINKLE_DEEP_HEALTH_URL="${TWINKLE_DEEP_HEALTH_URL:-http://127.0.0.1:9000/api/v1/twinkle/healthz/deep}" TWINKLE_WATCHDOG_INTERVAL_SECONDS="${TWINKLE_WATCHDOG_INTERVAL_SECONDS:-10}" TWINKLE_WATCHDOG_FAILURE_THRESHOLD="${TWINKLE_WATCHDOG_FAILURE_THRESHOLD:-3}" TWINKLE_RAY_GRACE_SECONDS="${TWINKLE_RAY_GRACE_SECONDS:-30}" TWINKLE_HEALTH_GRACE_SECONDS="${TWINKLE_HEALTH_GRACE_SECONDS:-${TWINKLE_WATCHDOG_STARTUP_GRACE_SECONDS:-300}}" +TWINKLE_DEEP_HEALTH_GRACE_SECONDS="${TWINKLE_DEEP_HEALTH_GRACE_SECONDS:-${TWINKLE_HEALTH_GRACE_SECONDS:-300}}" RESTART_BACKOFF_SECONDS="${TWINKLE_ENTRYPOINT_RESTART_BACKOFF_SECONDS:-10}" CHILD_PID="" @@ -58,6 +60,7 @@ validate_entrypoint_config() { require_positive_int "TWINKLE_WATCHDOG_FAILURE_THRESHOLD" "$TWINKLE_WATCHDOG_FAILURE_THRESHOLD" require_non_negative_int "TWINKLE_RAY_GRACE_SECONDS" "$TWINKLE_RAY_GRACE_SECONDS" require_non_negative_int "TWINKLE_HEALTH_GRACE_SECONDS" "$TWINKLE_HEALTH_GRACE_SECONDS" + require_non_negative_int "TWINKLE_DEEP_HEALTH_GRACE_SECONDS" "$TWINKLE_DEEP_HEALTH_GRACE_SECONDS" require_non_negative_int "TWINKLE_ENTRYPOINT_RESTART_BACKOFF_SECONDS" "$RESTART_BACKOFF_SECONDS" require_command timeout @@ -72,13 +75,14 @@ validate_entrypoint_config() { } check_http_health() { + local url="${1:-$TWINKLE_HEALTH_URL}" if command -v curl &> /dev/null; then - curl -fsS --max-time 10 "$TWINKLE_HEALTH_URL" >/dev/null + curl -fsS --max-time 10 "$url" >/dev/null return fi if command -v wget &> /dev/null; then - wget -q --spider --timeout=10 "$TWINKLE_HEALTH_URL" + wget -q -O /dev/null --timeout=10 "$url" return fi @@ -86,7 +90,7 @@ check_http_health() { if ! command -v "$python_bin" &> /dev/null; then python_bin="python" fi - "$python_bin" - "$TWINKLE_HEALTH_URL" <<'PY' + "$python_bin" - "$url" <<'PY' import sys import urllib.request @@ -102,6 +106,7 @@ PY print_watchdog_diagnostics() { print_warning "EntryPoint watchdog 诊断信息:" echo " - health url: $TWINKLE_HEALTH_URL" + echo " - deep health url: $TWINKLE_DEEP_HEALTH_URL" echo " - run.sh pid: ${CHILD_PID:-unset}" echo " - Ray logs: $TEMP_DIR/session_latest/logs" @@ -161,6 +166,9 @@ while true; do elif ! check_http_health; then WATCHDOG_FAILURE_REASON="http health check failed: $TWINKLE_HEALTH_URL" WATCHDOG_GRACE_SECONDS="$TWINKLE_HEALTH_GRACE_SECONDS" + elif ! check_http_health "$TWINKLE_DEEP_HEALTH_URL"; then + WATCHDOG_FAILURE_REASON="deep health check failed (model actors may be dead): $TWINKLE_DEEP_HEALTH_URL" + WATCHDOG_GRACE_SECONDS="$TWINKLE_DEEP_HEALTH_GRACE_SECONDS" fi if [ -z "$WATCHDOG_FAILURE_REASON" ]; then diff --git a/cookbook/client/server/megatron/server_config.yaml b/cookbook/client/server/megatron/server_config.yaml index c93680066..67ef5fdaa 100644 --- a/cookbook/client/server/megatron/server_config.yaml +++ b/cookbook/client/server/megatron/server_config.yaml @@ -54,6 +54,7 @@ applications: env_vars: TWINKLE_TRUST_REMOTE_CODE: "0" TWINKLE_LONG_POLL_TIMEOUT: "120" + TWINKLE_FAIL_FAST: "0" # 3. Sampler Service - Runs inference / sampling using vLLM engine # Used for generating text from the model (e.g., evaluating LoRA results). @@ -66,7 +67,7 @@ applications: nproc_per_node: 4 # Number of GPU processes per node sampler_type: vllm # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler) engine_args: # vLLM engine-specific settings - max_model_len: 65536 # Maximum sequence length the engine supports + max_model_len: 32768 # Maximum sequence length the engine supports gpu_memory_utilization: 0.75 # 80% utilization, ~64GB/GPU, leaves buffer for safety enable_lora: true # Allow loading LoRA adapters during inference max_loras: 5 # Max allowed loras working on vLLM at the same time @@ -84,7 +85,7 @@ applications: queue_config: rps_limit: 20 # Max requests per second tps_limit: 131072 # Max tokens per second - max_input_tokens: 65536 + max_input_tokens: 32768 deployments: - name: SamplerManagement autoscaling_config: @@ -97,6 +98,7 @@ applications: env_vars: TWINKLE_TRUST_REMOTE_CODE: "0" TWINKLE_LONG_POLL_TIMEOUT: "120" + TWINKLE_FAIL_FAST: "0" # 2. Model Service - Hosts the base model for training. # Config: PP=2 x DP=2 on 4 GPUs, ~27GB weights/GPU, comfortable for LoRA training @@ -105,8 +107,8 @@ applications: import_path: model args: backend: megatron # Use Megatron-LM backend - model_id: "ms://Qwen/Qwen3.6-27B" # ModelScope model identifier - max_length: 65536 # model max length + model_id: "ms://Qwen/Qwen3.6-27B" # ModelScope model identifier + max_length: 32768 # model max length max_loras: 3 # model max loras nproc_per_node: 4 # Number of GPU processes per node device_group: @@ -121,7 +123,7 @@ applications: queue_config: rps_limit: 20 # Max requests per second tps_limit: 131072 # Max tokens per second - max_input_tokens: 65536 + max_input_tokens: 32768 adapter_config: adapter_timeout: 120 # Seconds before idle adapter unload adapter_max_lifetime: 36000 # Maximum lifetime of an adapter in seconds (e.g., 10 hours) @@ -137,6 +139,7 @@ applications: env_vars: TWINKLE_TRUST_REMOTE_CODE: "0" TWINKLE_LONG_POLL_TIMEOUT: "120" + TWINKLE_FAIL_FAST: "0" # 4. Processor Service - name: processor @@ -159,3 +162,6 @@ applications: target_ongoing_requests: 128 ray_actor_options: num_cpus: 0.1 + runtime_env: + env_vars: + TWINKLE_FAIL_FAST: "0" diff --git a/cookbook/client/server/megatron/server_config_4b.yaml b/cookbook/client/server/megatron/server_config_4b.yaml index 7eed4699d..9bdbd5e72 100644 --- a/cookbook/client/server/megatron/server_config_4b.yaml +++ b/cookbook/client/server/megatron/server_config_4b.yaml @@ -31,6 +31,9 @@ applications: target_ongoing_requests: 128 # Target concurrent requests per replica ray_actor_options: num_cpus: 0.1 # CPU resources allocated to this actor + runtime_env: + env_vars: + TWINKLE_FAIL_FAST: "0" # 2. Model Service (commented out) - Would host the base model for training. # Uncomment and configure if you need a training model worker. @@ -68,6 +71,7 @@ applications: runtime_env: env_vars: TWINKLE_TRUST_REMOTE_CODE: "0" + TWINKLE_FAIL_FAST: "0" # 3. Sampler Service - Runs inference / sampling using vLLM engine # Used for generating text from the model (e.g., evaluating LoRA results). @@ -105,6 +109,7 @@ applications: runtime_env: env_vars: TWINKLE_TRUST_REMOTE_CODE: "0" + TWINKLE_FAIL_FAST: "0" # 4. Processor Service - name: processor @@ -127,3 +132,6 @@ applications: target_ongoing_requests: 128 ray_actor_options: num_cpus: 0.1 + runtime_env: + env_vars: + TWINKLE_FAIL_FAST: "0" diff --git a/cookbook/client/server/transformer/server_config.yaml b/cookbook/client/server/transformer/server_config.yaml index 183bfff24..d3ddb2adb 100644 --- a/cookbook/client/server/transformer/server_config.yaml +++ b/cookbook/client/server/transformer/server_config.yaml @@ -11,7 +11,7 @@ http_options: # Telemetry: push traces/metrics/logs to LGTM's OTel Collector via OTLP telemetry: - enabled: true + enabled: false otlp_endpoint: http://localhost:4317 # Persistence configuration for ServerState (sessions, models, futures, ...). @@ -49,6 +49,9 @@ applications: target_ongoing_requests: 128 # Target concurrent requests per replica ray_actor_options: num_cpus: 0.1 # CPU resources allocated to this actor + runtime_env: + env_vars: + TWINKLE_FAIL_FAST: "0" # 2. Model Service - Hosts the base model for training. - name: models-Qwen3.5-4B @@ -81,43 +84,45 @@ applications: num_cpus: 0.1 runtime_env: env_vars: - TWINKLE_TRUST_REMOTE_CODE: "0" + TWINKLE_TRUST_REMOTE_CODE: "1" + TWINKLE_FAIL_FAST: "0" # 3. Sampler Service - Runs inference / sampling using vLLM engine # Used for generating text from the model (e.g., evaluating LoRA results). - # - name: sampler-Qwen3.5-4B - # route_prefix: /api/v1/sampler/Qwen/Qwen3.5-4B - # import_path: sampler - # args: - # model_id: "ms://Qwen/Qwen3.5-4B" # ModelScope model identifier - # nproc_per_node: 2 # Number of GPU processes per node - # sampler_type: vllm # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler) - # engine_args: # vLLM engine-specific settings - # max_model_len: 4096 # Maximum sequence length the engine supports - # gpu_memory_utilization: 0.5 # Fraction of GPU memory to use (0.0-1.0) - # enable_lora: true # Allow loading LoRA adapters during inference - # logprobs_mode: processed_logprobs # Logprobs mode for sampling results - # device_group: # Logical device group for the sampler - # name: sampler - # ranks: 1 # Number of GPUs to use - # device_type: cuda - # device_mesh: - # device_type: cuda - # dp_size: 1 - # queue_config: - # rps_limit: 100 # Max requests per second - # tps_limit: 100000 # Max tokens per second - # deployments: - # - name: SamplerManagement - # autoscaling_config: - # min_replicas: 1 - # max_replicas: 1 - # target_ongoing_requests: 16 - # ray_actor_options: - # num_cpus: 0.1 - # runtime_env: - # env_vars: - # TWINKLE_TRUST_REMOTE_CODE: "0" + - name: sampler-Qwen3.5-4B + route_prefix: /api/v1/sampler/Qwen/Qwen3.5-4B + import_path: sampler + args: + model_id: "ms://Qwen/Qwen3.5-4B" # ModelScope model identifier + nproc_per_node: 1 # Number of GPU processes per node + sampler_type: vllm # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler) + engine_args: # vLLM engine-specific settings + max_model_len: 4096 # Maximum sequence length the engine supports + gpu_memory_utilization: 0.5 # Fraction of GPU memory to use (0.0-1.0) + enable_lora: true # Allow loading LoRA adapters during inference + logprobs_mode: processed_logprobs # Logprobs mode for sampling results + device_group: # Logical device group for the sampler + name: sampler + ranks: 1 # Number of GPUs to use + device_type: cuda + device_mesh: + device_type: cuda + dp_size: 1 + queue_config: + rps_limit: 100 # Max requests per second + tps_limit: 100000 # Max tokens per second + deployments: + - name: SamplerManagement + autoscaling_config: + min_replicas: 1 + max_replicas: 1 + target_ongoing_requests: 16 + ray_actor_options: + num_cpus: 0.1 + runtime_env: + env_vars: + TWINKLE_TRUST_REMOTE_CODE: "1" + TWINKLE_FAIL_FAST: "0" # 4. Processor Service - name: processor @@ -140,3 +145,6 @@ applications: target_ongoing_requests: 128 ray_actor_options: num_cpus: 0.1 + runtime_env: + env_vars: + TWINKLE_FAIL_FAST: "0" diff --git a/cookbook/client/server/transformer/server_e2e.py b/cookbook/client/server/transformer/server_e2e.py deleted file mode 100644 index b61f8d25e..000000000 --- a/cookbook/client/server/transformer/server_e2e.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Boot the e2e variant of the transformer server (model + sampler + processor).""" -import os - -os.environ['TWINKLE_TRUST_REMOTE_CODE'] = '0' - -from twinkle.server import launch_server - -file_dir = os.path.abspath(os.path.dirname(__file__)) -config_path = os.path.join(file_dir, 'server_config_e2e.yaml') - -launch_server(config_path=config_path) diff --git a/cookbook/client/tinker/modelscope/dpo.py b/cookbook/client/tinker/modelscope/dpo.py index 23cf5aaee..cfbe916e9 100644 --- a/cookbook/client/tinker/modelscope/dpo.py +++ b/cookbook/client/tinker/modelscope/dpo.py @@ -51,7 +51,7 @@ max_length = 2048 lora_rank = 8 system_prompt = 'You are a helpful assistant.' -use_swanlab = True +use_swanlab = False # --------------------------------------------------------------------------- diff --git a/cookbook/client/tinker/modelscope/short_math_grpo.py b/cookbook/client/tinker/modelscope/short_math_grpo.py index bf57a9424..780df4448 100644 --- a/cookbook/client/tinker/modelscope/short_math_grpo.py +++ b/cookbook/client/tinker/modelscope/short_math_grpo.py @@ -165,7 +165,7 @@ def main(): if step % SYNC_INTERVAL == 0: logger.info(f'Step {step}: Saving weights for sampler...') - sampling_client = (training_client.save_weights_and_get_sampling_client(name=f'GSM8K-step-{step}')) + sampling_client = training_client.save_weights_and_get_sampling_client() logger.info(f'Step {step}: Sampling client ready') if sampling_client is None: diff --git a/cookbook/client/tinker/self_host/short_math_grpo.py b/cookbook/client/tinker/self_host/short_math_grpo.py index e19955a6c..f87fe812e 100644 --- a/cookbook/client/tinker/self_host/short_math_grpo.py +++ b/cookbook/client/tinker/self_host/short_math_grpo.py @@ -165,7 +165,7 @@ def main(): if step % SYNC_INTERVAL == 0: logger.info(f'Step {step}: Saving weights for sampler...') - sampling_client = (training_client.save_weights_and_get_sampling_client(name=f'GSM8K-step-{step}')) + sampling_client = training_client.save_weights_and_get_sampling_client() logger.info(f'Step {step}: Sampling client ready') if sampling_client is None: diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index b051546bb..7a94d4422 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -36,6 +36,7 @@ from twinkle.processor import InputProcessor from twinkle.template import Template from twinkle.utils import construct_class, get_logger, selective_log_softmax +from twinkle.utils.nccl_safe import _is_fail_fast from ._mindspeed_runtime import ensure_mindspeed_adaptor_patched from .strategy import MegatronStrategy @@ -407,6 +408,7 @@ def forward_step_func(data_iterator, model): output_tensor = model(**batch) else: output_tensor = model(**batch) + batch['labels'] = labels logps = None unpacked_logits = None @@ -414,44 +416,59 @@ def forward_step_func(data_iterator, model): embeddings = None _loss_instance = loss_instance is_last_pp = mpu.is_pipeline_last_stage(False, unwrapped_model.vp_stage) - if task == 'embedding': - # MegatronEmbeddingPatch already pooled output to [n_seqs, hidden] on last PP stage. - if is_last_pp: - embeddings = output_tensor - elif labels is not None and is_last_pp: - _loss_require_logps = getattr(_loss_instance, 'require_logps', True) - _loss_require_entropy = (hasattr(_loss_instance, 'require_entropy') and _loss_instance.require_entropy) - _packed = batch.get('packed_seq_params') - cu_seqlens_q = getattr(_packed, 'cu_seqlens_q', None) if _packed is not None else None - if _loss_require_logps: - loss_mask = (labels != -100).bool() - masked_labels = labels.clone() - masked_labels[~loss_mask] = 0 - output_tensor.div_(temperature) - if _loss_require_entropy: - logps, entropies = selective_log_softmax(output_tensor, masked_labels, return_entropy=True) - else: - logps = selective_log_softmax(output_tensor, masked_labels) - # Reconstruct full-length tensors from CP-split shards - logps = processor.postprocess_tensor_cp(logps, cu_seqlens=cu_seqlens_q) + try: + if task == 'embedding': + # MegatronEmbeddingPatch already pooled output to [n_seqs, hidden] on last PP stage. + if is_last_pp: + embeddings = output_tensor + elif labels is not None and is_last_pp: + _loss_require_logps = getattr(_loss_instance, 'require_logps', True) + _loss_require_entropy = ( + hasattr(_loss_instance, 'require_entropy') and _loss_instance.require_entropy) + _packed = batch.get('packed_seq_params') + cu_seqlens_q = getattr(_packed, 'cu_seqlens_q', None) if _packed is not None else None + if _loss_require_logps: + loss_mask = (labels != -100).bool() + masked_labels = labels.clone() + masked_labels[~loss_mask] = 0 + output_tensor.div_(temperature) + if _loss_require_entropy: + logps, entropies = selective_log_softmax(output_tensor, masked_labels, return_entropy=True) + else: + logps = selective_log_softmax(output_tensor, masked_labels) + # Reconstruct full-length tensors from CP-split shards + logps = processor.postprocess_tensor_cp(logps, cu_seqlens=cu_seqlens_q) + if entropies is not None: + entropies = processor.postprocess_tensor_cp(entropies, cu_seqlens=cu_seqlens_q) + batch['labels'] = processor.postprocess_tensor_cp(labels, cu_seqlens=cu_seqlens_q) + if 'position_ids' in batch: + pos = batch['position_ids'] + if pos.dim() == 3: + pos = pos[0] # [2/3, 1, seq] → [1, seq] + batch['position_ids'] = processor.postprocess_tensor_cp(pos, cu_seqlens=cu_seqlens_q) + # Unpack packed sequences into per-sequence batch format + _outputs = {'logps': logps} if entropies is not None: - entropies = processor.postprocess_tensor_cp(entropies, cu_seqlens=cu_seqlens_q) - batch['labels'] = processor.postprocess_tensor_cp(labels, cu_seqlens=cu_seqlens_q) - if 'position_ids' in batch: - pos = batch['position_ids'] - if pos.dim() == 3: - pos = pos[0] # [2/3, 1, seq] → [1, seq] - batch['position_ids'] = processor.postprocess_tensor_cp(pos, cu_seqlens=cu_seqlens_q) - # Unpack packed sequences into per-sequence batch format - _outputs = {'logps': logps} - if entropies is not None: - _outputs['entropies'] = entropies - if hasattr(_loss_instance, 'require_logits') and _loss_instance.require_logits: - _outputs['logits'] = output_tensor - batch, _outputs = processor.unpack_packed_sequences(batch, _outputs) - logps = _outputs['logps'] - entropies = _outputs.get('entropies', None) - unpacked_logits = _outputs.get('logits', None) + _outputs['entropies'] = entropies + if hasattr(_loss_instance, 'require_logits') and _loss_instance.require_logits: + _outputs['logits'] = output_tensor + batch, _outputs = processor.unpack_packed_sequences(batch, _outputs) + logps = _outputs['logps'] + entropies = _outputs.get('entropies', None) + unpacked_logits = _outputs.get('logits', None) + except Exception as e: + # Data processing error (e.g. unpack_packed_sequences dimension mismatch). + # Must catch here inside the scheduler to prevent exception escaping + # and breaking PP P2P communication → NCCL hang. + if _is_fail_fast(): + raise + logger.warning('[nccl_safe] forward_step_func data processing error: ' + '%s: %s', + type(e).__name__, e) + logps = None + unpacked_logits = None + entropies = None + embeddings = None return output_tensor, partial( post_loss_function, inputs=batch, diff --git a/src/twinkle/model/megatron/strategy/megatron.py b/src/twinkle/model/megatron/strategy/megatron.py index 819014eb8..fae4e4241 100644 --- a/src/twinkle/model/megatron/strategy/megatron.py +++ b/src/twinkle/model/megatron/strategy/megatron.py @@ -257,8 +257,8 @@ def reduce_loss(self, local_loss, local_count, logits, logps): grad_count = (count // cp_size).clamp(min=1) if cp_size > 1 else count return local_loss, grad_count, { 'loss': local_loss.detach(), - 'logits': logits.detach(), - 'logps': logps.detach(), + 'logits': logits.detach() if logits is not None else None, + 'logps': logps.detach() if logps is not None else None, 'num_tokens': count } diff --git a/src/twinkle/model/optimizer_group.py b/src/twinkle/model/optimizer_group.py index 5fdb89f74..384dffe42 100644 --- a/src/twinkle/model/optimizer_group.py +++ b/src/twinkle/model/optimizer_group.py @@ -47,6 +47,12 @@ class BaseOptimizerGroup: _device_mesh: DeviceMesh = None _last_grad_norm: float = 0.0 + def __setattr__(self, name, value): + if name == 'loss_instance' and value is not None: + from twinkle.utils.nccl_safe import safe_loss + value = safe_loss(value) + super().__setattr__(name, value) + def do_grad_sync(self, gradient_accumulation_steps: Optional[int] = None) -> bool: if gradient_accumulation_steps is None: gradient_accumulation_steps = self.gradient_accumulation_steps diff --git a/src/twinkle/processor/base.py b/src/twinkle/processor/base.py index 467355e51..63d80bb93 100644 --- a/src/twinkle/processor/base.py +++ b/src/twinkle/processor/base.py @@ -93,7 +93,7 @@ def prepare_outputs(self, inputs: List[InputFeature], **kwargs) -> Union[List[In return inputs[0] else: for _input in inputs: - if 'position_ids' in _input and _input['position_ids'].dim() > 2: + if 'position_ids' in _input and _input['position_ids'] is not None and _input['position_ids'].dim() > 2: # megatron needs 3, 1, N _input['position_ids'] = _input['position_ids'][1:] return inputs diff --git a/src/twinkle/server/gateway/twinkle_handlers.py b/src/twinkle/server/gateway/twinkle_handlers.py index 77cc67c62..c3de4d8a5 100644 --- a/src/twinkle/server/gateway/twinkle_handlers.py +++ b/src/twinkle/server/gateway/twinkle_handlers.py @@ -36,6 +36,44 @@ async def get_capacity_info( async def healthz(request: Request) -> types.HealthResponse: return types.HealthResponse(status='ok') + @app.get('/twinkle/healthz/deep') + async def healthz_deep( + request: Request, + self: GatewayServer = Depends(self_fn), + ) -> dict: + """Deep health check: verifies model actors are alive, not just the gateway. + + Returns 503 if any model deployment's actors are unreachable (e.g. OOM/SIGSEGV). + The entrypoint watchdog should poll this endpoint to detect silent failures. + """ + from fastapi.responses import JSONResponse + + results = {} + all_healthy = True + + for model in self.supported_models: + model_name = model.model_name + try: + resp = await self.proxy.proxy_request(request, 'healthz', model_name, 'model') + healthy = (resp.status_code == 200) + if not healthy: + all_healthy = False + results[model_name] = { + 'healthy': healthy, + 'status_code': resp.status_code, + } + except Exception as e: + all_healthy = False + results[model_name] = { + 'healthy': False, + 'detail': str(e), + } + + body = {'healthy': all_healthy, 'models': results} + if not all_healthy: + return JSONResponse(status_code=503, content=body) + return body + @app.get('/twinkle/get_server_capabilities', response_model=types.GetServerCapabilitiesResponse) async def get_server_capabilities( request: Request, diff --git a/src/twinkle/server/launcher/env_propagation.py b/src/twinkle/server/launcher/env_propagation.py index 4d4a7a474..3b0c9b956 100644 --- a/src/twinkle/server/launcher/env_propagation.py +++ b/src/twinkle/server/launcher/env_propagation.py @@ -21,6 +21,10 @@ 'TWINKLE_MODEL_ID_ALIASES', ) +# NCCL-safe env var keys: controls fault tolerance behavior in distributed +# training (safe_loss / @nccl_safe). Must reach model worker actors. +NCCL_SAFE_ENV_KEYS: tuple[str, ...] = ('TWINKLE_FAIL_FAST', ) + def build_telemetry_env_vars() -> dict[str, str]: """Collect telemetry env vars from ``os.environ`` for worker propagation.""" @@ -33,9 +37,15 @@ def build_persistence_env_vars() -> dict[str, str]: return {k: os.environ[k] for k in PERSISTENCE_ENV_KEYS if k in os.environ} +def build_nccl_safe_env_vars() -> dict[str, str]: + """Collect NCCL-safe env vars from ``os.environ`` for worker propagation.""" + return {k: os.environ[k] for k in NCCL_SAFE_ENV_KEYS if k in os.environ} + + def build_propagated_env_vars() -> dict[str, str]: """Aggregate all env vars that must reach Ray worker processes.""" merged: dict[str, str] = {} merged.update(build_telemetry_env_vars()) merged.update(build_persistence_env_vars()) + merged.update(build_nccl_safe_env_vars()) return merged diff --git a/src/twinkle/server/model/app.py b/src/twinkle/server/model/app.py index 261e34de9..bae7b2a93 100644 --- a/src/twinkle/server/model/app.py +++ b/src/twinkle/server/model/app.py @@ -147,6 +147,21 @@ async def shutdown(self) -> None: except Exception: pass + def check_model_health(self) -> dict: + """Probe model actors liveness via a lightweight ping. + + Returns a dict with 'healthy' (bool) and 'detail' (str). + If the model actors are dead (e.g. OOM/SIGSEGV), the ping call + will raise RayActorError, signalling the watchdog to restart. + """ + try: + result = self.model.ping() + if result is True: + return {'healthy': True, 'detail': 'model actors alive'} + return {'healthy': False, 'detail': f'unexpected ping result: {result}'} + except Exception as e: + return {'healthy': False, 'detail': f'model actor unreachable: {e}'} + async def _cleanup_adapter(self, adapter_name: str) -> None: if self.get_resource_info(adapter_name): self.clear_resource_state(adapter_name) diff --git a/src/twinkle/server/model/backends/common.py b/src/twinkle/server/model/backends/common.py index 2ac4f3eec..7c90a4ed0 100644 --- a/src/twinkle/server/model/backends/common.py +++ b/src/twinkle/server/model/backends/common.py @@ -204,7 +204,19 @@ def _tensor_output_to_rows(value, seq_lens: list[int], *, kind: str) -> list[tor # Non-last PP stages can legitimately produce no logits/logps. return None elif not (isinstance(value, list) and all(isinstance(item, torch.Tensor) for item in value)): - tensors = [torch.as_tensor(value, dtype=torch.float32)] + # Handle ragged list[list] (e.g. logps after to_cpu_safe_output + # converted variable-length Tensors to nested Python lists). + # Flatten microbatch grouping → per-sample 1D lists → pad_and_stack. + if isinstance(value, list) and value and isinstance(value[0], (list, tuple)): + flat = [s for item in value for s in (item if isinstance(item[0], (list, tuple)) else [item])] + from twinkle.utils import pad_and_stack_tensors + tensors = [ + pad_and_stack_tensors([torch.tensor(s, dtype=torch.float32) for s in flat], + pad_value=0.0, + concat=False) + ] + else: + tensors = [torch.as_tensor(value, dtype=torch.float32)] tensors = [tensor.detach().cpu() for tensor in tensors] if len(tensors) == len(seq_lens) and all(tensor.dim() <= 1 or tensor.shape[0] == 1 for tensor in tensors): diff --git a/src/twinkle/server/model/backends/megatron_model.py b/src/twinkle/server/model/backends/megatron_model.py index 35584508d..21b10ee73 100644 --- a/src/twinkle/server/model/backends/megatron_model.py +++ b/src/twinkle/server/model/backends/megatron_model.py @@ -13,6 +13,7 @@ from twinkle.server.common.datum import datum_to_input_feature, extract_rl_features_for_loss from twinkle.server.model.backends.common import (TwinkleCompatModelBase, clean_metrics, collect_forward_backward_results, to_cpu_safe_output) +from twinkle.utils.nccl_safe import nccl_safe_megatron @remote_class(execute='all') @@ -23,6 +24,7 @@ class TwinkleCompatMegatronModel(MultiLoraMegatronModel, TwinkleCompatModelBase) """ @remote_function(dispatch='slice_dp', collect=collect_forward_backward_results, sync=True) + @nccl_safe_megatron(tinker=True) def tinker_forward_backward(self, *, inputs: list[types.Datum], adapter_name: str, loss_fn: str, **kwargs): """Combined forward and backward pass.""" self._tinker_setup_loss(loss_fn, inputs, adapter_name, kwargs) @@ -43,6 +45,7 @@ def tinker_forward_backward(self, *, inputs: list[types.Datum], adapter_name: st return [results, loss] @remote_function(dispatch='slice_dp', collect=collect_forward_backward_results) + @nccl_safe_megatron(tinker=True) def tinker_forward_only(self, *, inputs: list[types.Datum], adapter_name: str = None, **kwargs): """Forward pass without gradient computation.""" template = self.get_template(adapter_name) @@ -97,13 +100,34 @@ def tinker_load(self, checkpoint_dir: str, **kwargs): # ------------------------------------------------------------------ @remote_function(dispatch='slice_dp', collect=collect_tensor_dict) + @nccl_safe_megatron(forward_only=True) def forward_only(self, *, inputs: InputFeature | list[InputFeature] | Trajectory | list[Trajectory], **kwargs): """Forward-only for twinkle-native clients (InputFeature/Trajectory I/O).""" output = super().forward_only(inputs=inputs, **kwargs) return to_cpu_safe_output(output) - @remote_function(dispatch='slice_dp', collect=collect_tensor_dict) + @remote_function(dispatch='slice_dp', collect=collect_tensor_dict, sync=True) + @nccl_safe_megatron def forward_backward(self, *, inputs: InputFeature | list[InputFeature] | Trajectory | list[Trajectory], **kwargs): """Forward+backward for twinkle-native clients (InputFeature/Trajectory I/O).""" + # Normalize ragged ref_outputs logps into a regular 2D tensor. + # After HTTP + collect_tensor_dict, logps is a nested list grouped + # by microbatch with varying seq_lens across DP ranks. Flatten to + # per-sample 1D lists and pad_and_stack — same as datum.py L84-88. + ref_outputs = kwargs.get('ref_outputs') + if isinstance(ref_outputs, dict) and 'logps' in ref_outputs: + logps = ref_outputs['logps'] + if isinstance(logps, (list, tuple)) and logps and not isinstance(logps[0], torch.Tensor): + # Flatten [[mb0_sample0, mb0_sample1], [mb1_sample0, ...]] → [sample0, sample1, ...] + flat = [s for item in logps for s in (item if isinstance(item[0], (list, tuple)) else [item])] + from twinkle.utils import pad_and_stack_tensors + ref_outputs['logps'] = pad_and_stack_tensors([torch.tensor(s, dtype=torch.float32) for s in flat], + pad_value=0.0, + concat=False) output = super().forward_backward(inputs=inputs, **kwargs) return to_cpu_safe_output(output) + + @remote_function(collect='first', lazy_collect=False) + def ping(self) -> bool: + """Lightweight liveness probe for watchdog health checks.""" + return True diff --git a/src/twinkle/server/model/backends/mock_model.py b/src/twinkle/server/model/backends/mock_model.py index ac89935c0..bc0fc06f2 100644 --- a/src/twinkle/server/model/backends/mock_model.py +++ b/src/twinkle/server/model/backends/mock_model.py @@ -235,6 +235,11 @@ def remove_adapter(self, adapter_name: str) -> None: def has_adapter(self, adapter_name: str) -> bool: return adapter_name in self._adapters + @remote_function(collect='first', lazy_collect=False) + def ping(self) -> bool: + """Lightweight liveness probe for watchdog health checks.""" + return True + def _to_tinker_loss_outputs(records: list[dict[str, Any]]) -> list[dict[str, Any]]: """Wrap mock's plain numpy-derived dicts into ``tinker.TensorData`` instances. diff --git a/src/twinkle/server/model/backends/transformers_model.py b/src/twinkle/server/model/backends/transformers_model.py index 0e8730833..e7677b619 100644 --- a/src/twinkle/server/model/backends/transformers_model.py +++ b/src/twinkle/server/model/backends/transformers_model.py @@ -16,6 +16,7 @@ from twinkle.server.common.datum import datum_to_input_feature, extract_rl_features_for_loss from twinkle.server.model.backends.common import (TwinkleCompatModelBase, clean_metrics, collect_forward_backward_results, to_cpu_safe_output) +from twinkle.utils.nccl_safe import nccl_safe @remote_class() @@ -40,6 +41,7 @@ def tinker_forward_only(self, *, inputs: list[types.Datum], adapter_name: str = return [results, 0.0] @remote_function(dispatch='slice_dp', collect=collect_forward_backward_results) + @nccl_safe(tinker=True) def tinker_forward_backward(self, *, inputs: list[types.Datum], adapter_name: str, loss_fn: str, **kwargs): self._tinker_setup_loss(loss_fn, inputs, adapter_name, kwargs) template = self.get_template(adapter_name) @@ -98,7 +100,13 @@ def forward_only(self, *, inputs: InputFeature | list[InputFeature] | Trajectory return to_cpu_safe_output(output) @remote_function(dispatch='slice_dp', collect=collect_tensor_dict) + @nccl_safe def forward_backward(self, *, inputs: InputFeature | list[InputFeature] | Trajectory | list[Trajectory], **kwargs): """Forward+backward for twinkle-native clients (InputFeature/Trajectory I/O).""" output = super().forward_backward(inputs=inputs, **kwargs) return to_cpu_safe_output(output) + + @remote_function(collect='first', lazy_collect=False) + def ping(self) -> bool: + """Lightweight liveness probe for watchdog health checks.""" + return True diff --git a/src/twinkle/server/model/tinker_handlers.py b/src/twinkle/server/model/tinker_handlers.py index b6973db04..1ddc87489 100644 --- a/src/twinkle/server/model/tinker_handlers.py +++ b/src/twinkle/server/model/tinker_handlers.py @@ -272,11 +272,15 @@ async def _do_save_for_sampler(): if metadata.get('base_model'): payload['base_model'] = metadata['base_model'] sampling_session_id = await self.state.create_sampling_session(payload) - # Return ``tinker_path`` (not None): tinker SDK's - # ``_save_weights_for_sampler_async`` asserts ``result.path is not None``. - # ``sampling_session_id`` is still the canonical handle. - return types.SaveWeightsForSamplerResponseInternal( - path=tinker_path, sampling_session_id=sampling_session_id) + # Tinker SDK distinguishes two modes by whether sampling_session_seq_id is set: + # 1. save_weights_for_sampler(name): expects path != None + # 2. save_weights_and_get_sampling_client(): expects path == None, sampling_session_id != None + if body.sampling_session_seq_id is not None: + return types.SaveWeightsForSamplerResponseInternal( + path=None, sampling_session_id=sampling_session_id) + else: + return types.SaveWeightsForSamplerResponseInternal( + path=tinker_path, sampling_session_id=sampling_session_id) except Exception: logger.error(traceback.format_exc()) return types.RequestFailedResponse( diff --git a/src/twinkle/server/model/twinkle_handlers.py b/src/twinkle/server/model/twinkle_handlers.py index 5f1fc02d8..2074f5e40 100644 --- a/src/twinkle/server/model/twinkle_handlers.py +++ b/src/twinkle/server/model/twinkle_handlers.py @@ -62,6 +62,18 @@ def _register_twinkle_routes(app: FastAPI, self_fn: Callable[[], ModelManagement replica instance. It is wired in via Depends so it is resolved lazily at request time. """ + @app.get('/healthz') + async def model_healthz( + request: Request, + self: ModelManagement = Depends(self_fn), + ) -> dict: + """Deep health probe: pings underlying model actors to verify liveness.""" + result = self.check_model_health() + if not result['healthy']: + from fastapi.responses import JSONResponse + return JSONResponse(status_code=503, content=result) + return result + async def run_task(coro): """Await a schedule_task_and_wait coroutine and surface any exception as a structured HTTP 500 response so the client receives the full traceback instead diff --git a/src/twinkle/server/utils/validation.py b/src/twinkle/server/utils/validation.py index b222eb9be..ef779cd27 100644 --- a/src/twinkle/server/utils/validation.py +++ b/src/twinkle/server/utils/validation.py @@ -32,7 +32,7 @@ async def verify_request_token(request: Request, call_next): return JSONResponse(status_code=403, content={'detail': 'Invalid token'}) path = request.url.path - skip_sticky = (path.endswith('/healthz') or any(path.endswith(s) for s in _OPENAI_COMPAT_SUFFIXES)) + skip_sticky = ('/healthz' in path or any(path.endswith(s) for s in _OPENAI_COMPAT_SUFFIXES)) if not skip_sticky: request_id = request.headers.get(H_REQUEST_ID) diff --git a/src/twinkle/utils/nccl_safe.py b/src/twinkle/utils/nccl_safe.py new file mode 100644 index 000000000..8b9740d45 --- /dev/null +++ b/src/twinkle/utils/nccl_safe.py @@ -0,0 +1,332 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""NCCL-safe utilities for production distributed training. + +Provides three layers of protection to prevent NCCL hangs: + +Layer 1 - safe_loss(): + Wraps loss instances to catch computation errors and return + graph-connected zero loss (ensures FSDP ReduceScatter can proceed). + +Layer 2 - @nccl_safe decorator: + Wraps forward_backward methods to ensure backward() always executes + after forward() has started, even if intermediate code raises. + +Layer 3 - @nccl_safe_megatron decorator: + Wraps Megatron backend methods (forward_only, forward_backward) where + the entire function body involves NCCL communication (sync=True). + Catches pre-communication errors (e.g. data preprocessing failures) + that would otherwise leave other DP ranks waiting at a collective. + +Controlled by environment variable: + TWINKLE_FAIL_FAST=1 (default, development): all protection is transparent, + exceptions propagate normally. + TWINKLE_FAIL_FAST=0 (production): protection activated, exceptions in + NCCL-critical sections are caught and handled gracefully. +""" +import functools +import os + +from twinkle.data_format import LossOutput +from twinkle.loss import Loss +from twinkle.utils.logger import get_logger + +logger = get_logger() + + +def _is_fail_fast() -> bool: + """Check if fail-fast mode is enabled (default: enabled). + + Returns True (fail-fast/development mode) unless TWINKLE_FAIL_FAST + is explicitly set to a falsy value. + """ + val = os.getenv('TWINKLE_FAIL_FAST', '1').upper() + return val not in ('0', 'NO', 'FALSE', 'OFF') + + +# ─── Layer 1: safe_loss ──────────────────────────────────────────────────── + + +def safe_loss(loss_instance): + """Wrap loss instance for production graceful degradation. + + Always wraps the loss instance (idempotent). The fail-fast check is deferred + to call time so that TWINKLE_FAIL_FAST can be set after wrapping (e.g. in + Ray actor processes where env vars may not be inherited from the launcher). + + When TWINKLE_FAIL_FAST=1 (default, development): wrapper is transparent, + exceptions propagate normally. + When TWINKLE_FAIL_FAST=0 (production): wrapper catches exceptions and + returns a graph-connected zero loss (ensures FSDP ReduceScatter proceeds). + + Idempotent: already-wrapped instances are returned as-is. + """ + if getattr(loss_instance, '_nccl_safe_wrapped', False): + return loss_instance + return SafeLossWrapper(loss_instance) + + +class SafeLossWrapper(Loss): + """Loss subclass that catches computation errors and returns graph-connected zero loss. + + Inherits from :class:`twinkle.loss.Loss` so ``isinstance(wrapper, Loss)`` + assertions in the training pipeline continue to pass. + """ + + def __init__(self, loss_instance): + super().__init__() + self._loss_instance = loss_instance + self.require_logps = getattr(loss_instance, 'require_logps', True) + self.require_entropy = getattr(loss_instance, 'require_entropy', False) + self.require_logits = getattr(loss_instance, 'require_logits', False) + self._nccl_safe_wrapped = True + + def __call__(self, inputs, outputs, **kwargs): + if _is_fail_fast(): + return self._loss_instance(inputs, outputs, **kwargs) + try: + return self._loss_instance(inputs, outputs, **kwargs) + except Exception as e: + import traceback + logger.warning('[nccl_safe] Loss computation skipped due to error: ' + '%s: %s\n%s', + type(e).__name__, e, traceback.format_exc()) + return _zero_loss(outputs) + + +def _zero_loss(outputs) -> 'LossOutput': + """Create a graph-connected zero loss for FSDP compatibility. + + Finds a gradient-bearing tensor from outputs to maintain graph connectivity, + ensuring backward hooks (ReduceScatter) fire. + """ + import torch + if isinstance(outputs, dict): + for key in ('logps', 'logits', 'loss'): + t = outputs.get(key) + if t is not None and isinstance(t, torch.Tensor) and t.requires_grad: + return LossOutput(loss=(t.flatten()[:1] * 0).sum(), num_tokens=0) + # Fallback: standalone zero tensor (may not trigger FSDP hooks) + device = 'cpu' + if isinstance(outputs, dict): + for v in outputs.values(): + if hasattr(v, 'device'): + device = v.device + break + return LossOutput(loss=torch.zeros((), device=device, requires_grad=True), num_tokens=0) + + +# ─── Layer 2: @nccl_safe decorator ────────────────────────────────────────── + + +def nccl_safe(func=None, *, tinker=False): + """Decorator ensuring backward() executes if forward() has already run. + + Detects forward completion by comparing train_status.outputs before/after + the wrapped function call. If an exception occurs after forward has run + but before backward completes, forces a zero-gradient backward pass to + prevent NCCL hang (other ranks waiting for ReduceScatter). + + Args: + func: The function to decorate (when used without arguments). + tinker: If True, fallback returns ``[[], 0.0]`` (tinker format). + If False, fallback returns outputs dict with ``loss=0.0``. + + Usage:: + + @remote_function(dispatch='slice_dp', collect=...) + @nccl_safe(tinker=True) + def tinker_forward_backward(self, *, inputs, adapter_name, ...): + # method body completely unchanged + ... + + @remote_function(dispatch='slice_dp', collect=...) + @nccl_safe + def forward_backward(self, *, inputs, **kwargs): + # method body completely unchanged + ... + """ + + def decorator(fn): + + @functools.wraps(fn) + def wrapper(self, *args, **kwargs): + if _is_fail_fast(): + return fn(self, *args, **kwargs) + + # Extract adapter_name for state tracking + adapter_name = kwargs.get('adapter_name') + if adapter_name is None and hasattr(self, '_get_default_group'): + adapter_name = self._get_default_group() + + og = self.optimizer_group.get(adapter_name) if adapter_name else None + if og is None: + # Cannot track state without optimizer group, passthrough + return fn(self, *args, **kwargs) + + # Snapshot state before call to detect forward completion + outputs_before = og.train_status.outputs + + try: + return fn(self, *args, **kwargs) + except Exception as e: + outputs_after = og.train_status.outputs + forward_ran = (outputs_after is not None and outputs_after is not outputs_before) + + if not forward_ran: + # Pre-forward failure: no NCCL ops started, safe to propagate + raise + + # Forward completed. Check if backward already ran. + # TransformersModel.backward() clears loss_value to None. + backward_done = (og.train_status.loss_value is None) + + if backward_done: + # Post-backward failure (e.g. output formatting) + # No NCCL hang risk, just return gracefully + logger.warning(f'[nccl_safe] Post-backward error (no NCCL risk): ' + f'{type(e).__name__}: {e}') + else: + # CRITICAL: forward ran but backward didn't → NCCL hang risk! + logger.warning(f'[nccl_safe] Forcing zero backward to prevent NCCL hang: ' + f'{type(e).__name__}: {e}') + _force_zero_backward(self, og, adapter_name, kwargs) + + # Return fallback result + if tinker: + return [[], 0.0] + outputs_after['loss'] = 0.0 + return outputs_after + + return wrapper + + if func is not None: + # @nccl_safe without arguments + return decorator(func) + # @nccl_safe(tinker=True) with arguments + return decorator + + +def _iter_model_params(model): + """Iterate parameters from ``model.model``, supporting single model or list of models.""" + raw_model = getattr(model, 'model', None) + if raw_model is None: + return iter([]) + if isinstance(raw_model, (list, tuple)): + for m in raw_model: + yield from m.parameters() + else: + yield from raw_model.parameters() + + +def _force_zero_backward(model, og, adapter_name, kwargs): + """Force a zero-gradient backward pass to prevent NCCL hang. + + Creates a graph-connected zero loss tensor and calls backward(), + ensuring FSDP ReduceScatter hooks fire on all ranks. + """ + import torch + + outputs = og.train_status.outputs + + # Find a graph-connected tensor for zero loss + zero_loss = None + if outputs is not None and isinstance(outputs, dict): + for key in ('logps', 'logits', 'loss'): + t = outputs.get(key) + if t is not None and isinstance(t, torch.Tensor) and t.requires_grad: + zero_loss = (t.flatten()[:1] * 0).sum() + break + + if zero_loss is None: + # Fallback: use first model parameter to maintain graph connectivity. + # Do NOT detach() the parameter -- the zero loss must remain connected + # to the model's autograd graph so FSDP ReduceScatter hooks fire. + # Use lazy iteration to avoid materializing the full parameter list. + try: + param = next((p for p in _iter_model_params(model) if p.requires_grad), None) + if param is not None: + zero_loss = (param.flatten()[0] * 0).sum() + else: + zero_loss = torch.zeros((), device='cuda', requires_grad=True) + except Exception: + zero_loss = torch.zeros((), device='cuda', requires_grad=True) + + og.train_status.loss_value = zero_loss + + # Call backward with minimal kwargs + bwd_kwargs = {'adapter_name': adapter_name} + gas = kwargs.get('gradient_accumulation_steps') + if gas is not None: + bwd_kwargs['gradient_accumulation_steps'] = gas + model.backward(**bwd_kwargs) + + +# ─── Layer 3: @nccl_safe_megatron decorator ────────────────────────────────── + + +def nccl_safe_megatron(func=None, *, tinker=False, forward_only=False): + """Decorator for Megatron backend methods where the entire body is NCCL-critical. + + Unlike @nccl_safe (which detects forward/backward boundaries), this decorator + treats the **entire function** as a NCCL-critical section. In Megatron, + forward_only and forward_backward both call get_forward_backward_func() which + requires all DP ranks to enter synchronously. If one rank fails during data + preprocessing (before entering Megatron's scheduler), other ranks will hang + waiting for the collective. + + This decorator catches ALL exceptions (when TWINKLE_FAIL_FAST=0) and returns + a safe fallback value, preventing NCCL hang from asymmetric failures. + + Args: + func: The function to decorate (when used without arguments). + tinker: If True, fallback returns ``[[], 0.0]`` (tinker format). + forward_only: If True, fallback returns empty dict ``{}`` (forward_only format). + + Usage:: + + @remote_function(dispatch='slice_dp', collect=..., sync=True) + @nccl_safe_megatron + def forward_backward(self, *, inputs, **kwargs): + ... + + @remote_function(dispatch='slice_dp', collect=...) + @nccl_safe_megatron(forward_only=True) + def forward_only(self, *, inputs, **kwargs): + ... + + @remote_function(dispatch='slice_dp', collect=..., sync=True) + @nccl_safe_megatron(tinker=True) + def tinker_forward_backward(self, *, inputs, **kwargs): + ... + """ + + def decorator(fn): + + @functools.wraps(fn) + def wrapper(self, *args, **kwargs): + if _is_fail_fast(): + return fn(self, *args, **kwargs) + + try: + return fn(self, *args, **kwargs) + except Exception as e: + import traceback + logger.warning(f'[nccl_safe_megatron] Exception in Megatron method ' + f'{fn.__name__}: {type(e).__name__}: {e}\n' + f'{traceback.format_exc()}') + + # Return safe fallback to prevent NCCL hang on other ranks + if tinker: + return [[], 0.0] + if forward_only: + return {} + # forward_backward fallback: return dict with loss=0.0 + return {'loss': 0.0} + + return wrapper + + if func is not None: + # @nccl_safe_megatron without arguments + return decorator(func) + # @nccl_safe_megatron(tinker=True) with arguments + return decorator diff --git a/cookbook/client/server/transformer/server_config_e2e.yaml b/tests/server/config/server_config_4b_e2e.yaml similarity index 83% rename from cookbook/client/server/transformer/server_config_e2e.yaml rename to tests/server/config/server_config_4b_e2e.yaml index 8c3b7cf05..6a9dfd317 100644 --- a/cookbook/client/server/transformer/server_config_e2e.yaml +++ b/tests/server/config/server_config_4b_e2e.yaml @@ -1,17 +1,10 @@ -# E2E test config: model + sampler + processor all enabled. -# Used only by full_cycle_e2e.py (client) + server_e2e.py (server boot). -# Mirror of server_config.yaml with the sampler block uncommented and a -# smaller persistence path so the e2e run doesn't clobber the demo state. +# Twinkle Server Configuration - E2E Test (4B model, Transformers backend) proxy_location: EveryNode http_options: host: 0.0.0.0 - port: 8000 - -persistence: - mode: file - file_path: /tmp/twinkle_state_e2e.json + port: 9000 applications: @@ -20,7 +13,7 @@ applications: import_path: server args: server_config: - per_token_model_limit: 3 + per_token_model_limit: 30 supported_models: - Qwen/Qwen3.5-4B deployments: @@ -32,6 +25,9 @@ applications: target_ongoing_requests: 128 ray_actor_options: num_cpus: 0.1 + runtime_env: + env_vars: + TWINKLE_FAIL_FAST: "0" - name: models-Qwen3.5-4B route_prefix: /api/v1/model/Qwen/Qwen3.5-4B @@ -40,19 +36,20 @@ applications: backend: transformers model_id: "ms://Qwen/Qwen3.5-4B" max_length: 10240 - nproc_per_node: 1 + nproc_per_node: 2 device_group: name: model - ranks: 1 + ranks: 2 device_type: cuda device_mesh: device_type: cuda - dp_size: 1 + dp_size: 2 queue_config: rps_limit: 100 tps_limit: 100000 adapter_config: - adapter_timeout: 60 + adapter_timeout: 30 + adapter_max_lifetime: 36000 deployments: - name: ModelManagement autoscaling_config: @@ -64,6 +61,7 @@ applications: runtime_env: env_vars: TWINKLE_TRUST_REMOTE_CODE: "1" + TWINKLE_FAIL_FAST: "0" - name: sampler-Qwen3.5-4B route_prefix: /api/v1/sampler/Qwen/Qwen3.5-4B @@ -98,6 +96,7 @@ applications: runtime_env: env_vars: TWINKLE_TRUST_REMOTE_CODE: "1" + TWINKLE_FAIL_FAST: "0" - name: processor route_prefix: /api/v1/processor @@ -105,7 +104,7 @@ applications: args: ncpu_proc_per_node: 2 device_group: - name: processor + name: model ranks: 2 device_type: CPU device_mesh: @@ -119,3 +118,6 @@ applications: target_ongoing_requests: 128 ray_actor_options: num_cpus: 0.1 + runtime_env: + env_vars: + TWINKLE_FAIL_FAST: "0" diff --git a/tests/server/config/server_config_4b_e2e_megatron.yaml b/tests/server/config/server_config_4b_e2e_megatron.yaml new file mode 100644 index 000000000..a6f193501 --- /dev/null +++ b/tests/server/config/server_config_4b_e2e_megatron.yaml @@ -0,0 +1,123 @@ +# Twinkle Server Configuration - E2E Test (4B model, Megatron DP=2 PP=2) + +proxy_location: EveryNode + +http_options: + host: 0.0.0.0 + port: 9000 + +applications: + + - name: server + route_prefix: /api/v1 + import_path: server + args: + server_config: + per_token_model_limit: 30 + supported_models: + - Qwen/Qwen3.5-4B + deployments: + - name: TinkerCompatServer + max_ongoing_requests: 50 + autoscaling_config: + min_replicas: 1 + max_replicas: 1 + target_ongoing_requests: 128 + ray_actor_options: + num_cpus: 0.1 + runtime_env: + env_vars: + TWINKLE_FAIL_FAST: "0" + + - name: models-Qwen3.5-4B + route_prefix: /api/v1/model/Qwen/Qwen3.5-4B + import_path: model + args: + backend: megatron + model_id: "ms://Qwen/Qwen3.5-4B" + max_length: 10240 + nproc_per_node: 4 + device_group: + name: model + ranks: 4 + device_type: cuda + device_mesh: + device_type: cuda + dp_size: 2 + pp_size: 2 + queue_config: + rps_limit: 100 + tps_limit: 100000 + adapter_config: + adapter_timeout: 30 + deployments: + - name: ModelManagement + autoscaling_config: + min_replicas: 1 + max_replicas: 1 + target_ongoing_requests: 16 + ray_actor_options: + num_cpus: 0.1 + runtime_env: + env_vars: + TWINKLE_TRUST_REMOTE_CODE: "1" + TWINKLE_FAIL_FAST: "0" + + - name: sampler-Qwen3.5-4B + route_prefix: /api/v1/sampler/Qwen/Qwen3.5-4B + import_path: sampler + args: + model_id: "ms://Qwen/Qwen3.5-4B" + nproc_per_node: 1 + sampler_type: vllm + engine_args: + max_model_len: 4096 + gpu_memory_utilization: 0.5 + enable_lora: true + logprobs_mode: processed_logprobs + device_group: + name: sampler + ranks: 1 + device_type: cuda + device_mesh: + device_type: cuda + dp_size: 1 + queue_config: + rps_limit: 100 + tps_limit: 100000 + deployments: + - name: SamplerManagement + autoscaling_config: + min_replicas: 1 + max_replicas: 1 + target_ongoing_requests: 16 + ray_actor_options: + num_cpus: 0.1 + runtime_env: + env_vars: + TWINKLE_TRUST_REMOTE_CODE: "1" + TWINKLE_FAIL_FAST: "0" + + - name: processor + route_prefix: /api/v1/processor + import_path: processor + args: + ncpu_proc_per_node: 2 + device_group: + name: model + ranks: 2 + device_type: CPU + device_mesh: + device_type: CPU + dp_size: 2 + deployments: + - name: ProcessorManagement + autoscaling_config: + min_replicas: 1 + max_replicas: 1 + target_ongoing_requests: 128 + ray_actor_options: + num_cpus: 0.1 + runtime_env: + env_vars: + TWINKLE_FAIL_FAST: "0" diff --git a/tests/server/config/server_config_4b_e2e_megatron_failfast1.yaml b/tests/server/config/server_config_4b_e2e_megatron_failfast1.yaml new file mode 100644 index 000000000..64af5194d --- /dev/null +++ b/tests/server/config/server_config_4b_e2e_megatron_failfast1.yaml @@ -0,0 +1,124 @@ +# Twinkle Server Configuration - E2E Test (4B model, Megatron DP=2 PP=2, FAIL_FAST=1) + +proxy_location: EveryNode + +http_options: + host: 0.0.0.0 + port: 9000 + +applications: + + - name: server + route_prefix: /api/v1 + import_path: server + args: + server_config: + per_token_model_limit: 20 + supported_models: + - Qwen/Qwen3.5-4B + deployments: + - name: TinkerCompatServer + max_ongoing_requests: 50 + autoscaling_config: + min_replicas: 1 + max_replicas: 1 + target_ongoing_requests: 128 + ray_actor_options: + num_cpus: 0.1 + runtime_env: + env_vars: + TWINKLE_FAIL_FAST: "1" + + - name: models-Qwen3.5-4B + route_prefix: /api/v1/model/Qwen/Qwen3.5-4B + import_path: model + args: + backend: megatron + model_id: "ms://Qwen/Qwen3.5-4B" + max_length: 10240 + nproc_per_node: 4 + device_group: + name: model + ranks: 4 + device_type: cuda + device_mesh: + device_type: cuda + dp_size: 2 + pp_size: 2 + queue_config: + rps_limit: 100 + tps_limit: 100000 + adapter_config: + adapter_timeout: 30 + adapter_max_lifetime: 36000 + deployments: + - name: ModelManagement + autoscaling_config: + min_replicas: 1 + max_replicas: 1 + target_ongoing_requests: 16 + ray_actor_options: + num_cpus: 0.1 + runtime_env: + env_vars: + TWINKLE_TRUST_REMOTE_CODE: "1" + TWINKLE_FAIL_FAST: "1" + + - name: sampler-Qwen3.5-4B + route_prefix: /api/v1/sampler/Qwen/Qwen3.5-4B + import_path: sampler + args: + model_id: "ms://Qwen/Qwen3.5-4B" + nproc_per_node: 1 + sampler_type: vllm + engine_args: + max_model_len: 4096 + gpu_memory_utilization: 0.5 + enable_lora: true + logprobs_mode: processed_logprobs + device_group: + name: sampler + ranks: 1 + device_type: cuda + device_mesh: + device_type: cuda + dp_size: 1 + queue_config: + rps_limit: 100 + tps_limit: 100000 + deployments: + - name: SamplerManagement + autoscaling_config: + min_replicas: 1 + max_replicas: 1 + target_ongoing_requests: 16 + ray_actor_options: + num_cpus: 0.1 + runtime_env: + env_vars: + TWINKLE_TRUST_REMOTE_CODE: "1" + TWINKLE_FAIL_FAST: "1" + + - name: processor + route_prefix: /api/v1/processor + import_path: processor + args: + ncpu_proc_per_node: 2 + device_group: + name: model + ranks: 2 + device_type: CPU + device_mesh: + device_type: CPU + dp_size: 2 + deployments: + - name: ProcessorManagement + autoscaling_config: + min_replicas: 1 + max_replicas: 1 + target_ongoing_requests: 128 + ray_actor_options: + num_cpus: 0.1 + runtime_env: + env_vars: + TWINKLE_FAIL_FAST: "1" diff --git a/tests/server/integration/e2e_helpers.py b/tests/server/integration/e2e_helpers.py new file mode 100644 index 000000000..9d3073bad --- /dev/null +++ b/tests/server/integration/e2e_helpers.py @@ -0,0 +1,271 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Shared utilities for E2E integration tests. + +Provides reusable helpers for server health check, client initialization, +dataset preparation, and test assertions across all 12 test combinations +(2 backends x 2 clients x 3 tasks). +""" +from __future__ import annotations + +import math +import os +import time +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np + +# ═══════════════════════════════════════════════════════════════════════════ +# Constants +# ═══════════════════════════════════════════════════════════════════════════ + +BASE_MODEL = 'Qwen/Qwen3.5-4B' +MODEL_ID = f'ms://{BASE_MODEL}' +BASE_URL = os.environ.get('TWINKLE_SERVER_URL', 'http://localhost:9000') +API_KEY = 'EMPTY_API_KEY' +TIMEOUT = 120 # seconds per operation before declaring hang +GRADIENT_ACCUMULATION_STEPS = 2 # Megatron requires GA >= 2 + + +def get_backend() -> str: + """Read backend type from environment variable.""" + backend = os.environ.get('TWINKLE_TEST_BACKEND', 'transformers').lower() + assert backend in ('transformers', 'megatron'), ( + f'Invalid TWINKLE_TEST_BACKEND={backend!r}, must be transformers or megatron') + return backend + + +# ═══════════════════════════════════════════════════════════════════════════ +# Server Health Check +# ═══════════════════════════════════════════════════════════════════════════ + +def wait_for_server(url: str = BASE_URL, timeout: int = 300) -> None: + """Wait for Twinkle server to become ready using Python requests.""" + import requests + + start = time.time() + while time.time() - start < timeout: + try: + resp = requests.get(f'{url}/-/routes', timeout=5) + if resp.status_code == 200: + elapsed = int(time.time() - start) + log(f'Server ready (waited {elapsed}s)') + return + except (requests.ConnectionError, requests.Timeout): + pass + time.sleep(5) + raise TimeoutError(f'Server not ready after {timeout}s at {url}') + + +def log(msg: str) -> None: + """Timestamped log output.""" + ts = time.strftime('%H:%M:%S') + print(f'[{ts}][E2E] {msg}', flush=True) + + +# ═══════════════════════════════════════════════════════════════════════════ +# Dataset Factories +# ═══════════════════════════════════════════════════════════════════════════ + +def create_sft_dataset(data_slice=range(100)): + """Create SelfCognition SFT dataset (small slice for speed).""" + from twinkle.dataloader import DataLoader + from twinkle.dataset import Dataset, DatasetMeta + + dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=data_slice)) + dataset.set_template('Qwen3_5Template', model_id=MODEL_ID, max_length=256) + dataset.map('SelfCognitionProcessor', init_args={'model_name': 'twinkle模型', 'model_author': 'ModelScope社区'}) + dataset.encode(batched=True) + return dataset + + +def create_dpo_dataset(data_slice=range(50)): + """Create EmojiDPO dataset with positive/negative format.""" + from twinkle.dataset import Dataset, DatasetMeta + from twinkle.preprocessor import EmojiDPOProcessor + + dataset = Dataset(DatasetMeta('ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji', data_slice=data_slice)) + dataset.set_template('Qwen3_5Template', model_id=MODEL_ID, max_length=1024) + dataset.map(EmojiDPOProcessor, init_args={'system': 'You are a helpful assistant.'}) + dataset.encode() + return dataset + + +def create_grpo_dataset(data_slice=range(50)): + """Create GSM8K dataset for GRPO training.""" + from twinkle.dataset import Dataset, DatasetMeta + from twinkle.preprocessor.llm import GSM8KProcessor + + system_prompt = ('You are a helpful math assistant. Solve the problem with minimal but correct reasoning ' + 'and put your final answer within \\boxed{}.') + dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train', data_slice=data_slice)) + dataset.set_template('Qwen3_5Template', model_id=MODEL_ID, max_length=2048, enable_thinking=False) + dataset.map(GSM8KProcessor(system=system_prompt)) + dataset.encode(add_generation_prompt=True) + return dataset + + +# ═══════════════════════════════════════════════════════════════════════════ +# Twinkle Client Factories +# ═══════════════════════════════════════════════════════════════════════════ + +def init_twinkle_client_session(): + """Initialize the Twinkle client session.""" + from twinkle import init_twinkle_client + return init_twinkle_client(base_url=BASE_URL, api_key=API_KEY) + + +def create_twinkle_sft_model(): + """Create Twinkle model configured for SFT (CrossEntropyLoss).""" + from peft import LoraConfig + from twinkle_client.model import MultiLoraTransformersModel + + model = MultiLoraTransformersModel(model_id=MODEL_ID) + model.add_adapter_to_model( + 'default', + LoraConfig(target_modules='all-linear'), + gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS, + ) + model.set_template('Qwen3_5Template') + model.set_processor('InputProcessor', padding_side='right') + model.set_loss('CrossEntropyLoss') + model.set_optimizer('Adam', lr=1e-4) + return model + + +def create_twinkle_dpo_model(): + """Create Twinkle model configured for DPO training.""" + from peft import LoraConfig + from twinkle_client.model import MultiLoraTransformersModel + + model = MultiLoraTransformersModel(model_id=MODEL_ID) + model.add_adapter_to_model( + 'default', + LoraConfig(target_modules='all-linear', r=8, lora_alpha=32, lora_dropout=0.05), + gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS, + ) + model.set_template('Qwen3_5Template') + model.set_processor('InputProcessor', padding_side='right') + model.set_loss('DPOLoss', beta=0.1, loss_type='sigmoid', reference_free=False, sft_weight=1.0) + model.add_metric('DPOMetric', beta=0.1) + model.set_optimizer('Adam', lr=1e-4) + return model + + +def create_twinkle_grpo_model(): + """Create Twinkle model configured for GRPO training.""" + from peft import LoraConfig + from twinkle_client.model import MultiLoraTransformersModel + + model = MultiLoraTransformersModel(model_id=MODEL_ID) + model.add_adapter_to_model( + 'default', + LoraConfig(target_modules='all-linear', r=8, lora_alpha=32, lora_dropout=0.05), + gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS, + ) + model.set_loss('GRPOLoss', epsilon=0.2, beta=0.0) + model.set_optimizer('Adam', lr=2e-5) + model.set_processor('InputProcessor') + model.set_template('Qwen3_5Template', model_id=MODEL_ID) + return model + + +def create_twinkle_sampler(): + """Create Twinkle vLLM sampler.""" + from twinkle_client.sampler import vLLMSampler + + sampler = vLLMSampler(model_id=MODEL_ID) + sampler.set_template('Qwen3_5Template', model_id=MODEL_ID) + return sampler + + +# ═══════════════════════════════════════════════════════════════════════════ +# Tinker Client Factories +# ═══════════════════════════════════════════════════════════════════════════ + +def init_tinker_client_session(): + """Initialize the Tinker client session and return ServiceClient.""" + from twinkle import init_tinker_client + init_tinker_client() + from tinker import ServiceClient + return ServiceClient(base_url=BASE_URL, api_key=API_KEY) + + +def create_tinker_training_client(rank: int = 8): + """Create Tinker LoRA training client.""" + service_client = init_tinker_client_session() + training_client = service_client.create_lora_training_client( + base_model=BASE_MODEL, + rank=rank, + ) + return training_client + + +# ═══════════════════════════════════════════════════════════════════════════ +# Data Processing Utilities +# ═══════════════════════════════════════════════════════════════════════════ + +def convert_tensors(batch: List[Dict[str, Any]]) -> None: + """Convert numpy/torch tensors to Python lists in-place for serialization.""" + import torch + + for row in batch: + for key in list(row.keys()): + val = row[key] + if isinstance(val, np.ndarray): + row[key] = val.tolist() + elif isinstance(val, torch.Tensor): + row[key] = val.cpu().numpy().tolist() + elif isinstance(val, dict): + for k2, v2 in val.items(): + if isinstance(v2, np.ndarray): + val[k2] = v2.tolist() + elif isinstance(v2, torch.Tensor): + val[k2] = v2.cpu().numpy().tolist() + + +def prepare_dpo_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Reorganize batch into DP-safe interleaved format [pos_1, neg_1, pos_2, neg_2, ...].""" + result = [] + for row in batch: + base_fields = {k: v for k, v in row.items() if k not in ('positive', 'negative')} + pos_sample = {**base_fields, **row['positive']} + neg_sample = {**base_fields, **row['negative']} + result.append(pos_sample) + result.append(neg_sample) + return result + + +# ═══════════════════════════════════════════════════════════════════════════ +# Assertions / Pass Criteria +# ═══════════════════════════════════════════════════════════════════════════ + +def assert_no_timeout(elapsed: float, label: str, timeout: float = TIMEOUT) -> None: + """Assert operation completed within timeout (no NCCL hang).""" + assert elapsed < timeout, ( + f'[{label}] TIMEOUT ({elapsed:.1f}s > {timeout}s) — possible NCCL hang!') + + +def assert_loss_decreases(losses: List[float], label: str) -> None: + """Assert training loss shows a downward trend. + + Verifies that the average of the last 3 loss values is lower than + the average of the first 3 loss values. + """ + assert len(losses) >= 4, f'[{label}] Need at least 4 loss values, got {len(losses)}' + first_avg = sum(losses[:3]) / 3 + last_avg = sum(losses[-3:]) / 3 + assert last_avg < first_avg, ( + f'[{label}] Loss did NOT decrease: first_3_avg={first_avg:.4f} >= last_3_avg={last_avg:.4f}') + log(f'[{label}] Loss decreased: {first_avg:.4f} -> {last_avg:.4f}') + + +def assert_metrics_valid(metrics: Any, label: str) -> None: + """Assert metrics contain finite (non-NaN, non-Inf) values.""" + if metrics is None: + return + if isinstance(metrics, dict): + for key, val in metrics.items(): + if isinstance(val, (int, float)): + assert math.isfinite(val), ( + f'[{label}] Metric {key}={val} is not finite!') + log(f'[{label}] Metrics valid: {metrics}') diff --git a/tests/server/integration/openai_e2e.py b/tests/server/integration/openai_e2e.py deleted file mode 100644 index 546798ca2..000000000 --- a/tests/server/integration/openai_e2e.py +++ /dev/null @@ -1,196 +0,0 @@ -# OpenAI-compatible endpoint E2E test -# -# Requires: server with vLLM sampler running (server_e2e.py config). -# Tests both non-streaming and streaming /v1/chat/completions via OpenAI SDK. -# -# Usage: -# python -u tests/server/integration/openai_e2e.py - -import sys -import time -from openai import OpenAI - -BASE_URL = 'http://127.0.0.1:8000/api/v1' -API_KEY = 'EMPTY_API_KEY' -MODEL = 'Qwen/Qwen3.5-4B' - - -def test_list_models(client: OpenAI): - print('--- Step 1: GET /models ---') - models = client.models.list() - model_ids = [m.id for m in models.data] - print(f' Available models: {model_ids}') - assert MODEL in model_ids, f'{MODEL} not in {model_ids}' - print(' PASS\n') - - -def test_non_streaming(client: OpenAI): - print('--- Step 2: Non-streaming chat completion ---') - t0 = time.time() - resp = client.chat.completions.create( - model=MODEL, - messages=[ - { - 'role': 'system', - 'content': 'You are a helpful assistant.' - }, - { - 'role': 'user', - 'content': 'What is 2+2? Answer in one word.' - }, - ], - max_tokens=32, - temperature=0.1, - ) - elapsed = time.time() - t0 - print(f' Model: {resp.model}') - print(f' Choices: {len(resp.choices)}') - content = resp.choices[0].message.content - print(f' Content: {content!r}') - print(f' Finish reason: {resp.choices[0].finish_reason}') - print(f' Usage: prompt={resp.usage.prompt_tokens}, completion={resp.usage.completion_tokens}') - print(f' Elapsed: {elapsed:.2f}s') - assert content and len(content) > 0, 'Empty response' - assert resp.choices[0].finish_reason in ('stop', 'length') - print(' PASS\n') - - -def test_streaming(client: OpenAI): - print('--- Step 3: Streaming chat completion ---') - t0 = time.time() - stream = client.chat.completions.create( - model=MODEL, - messages=[ - { - 'role': 'user', - 'content': 'Count from 1 to 5.' - }, - ], - max_tokens=64, - temperature=0.1, - stream=True, - ) - - chunks = [] - full_content = '' - for chunk in stream: - chunks.append(chunk) - delta = chunk.choices[0].delta - if delta.content: - full_content += delta.content - print(f' chunk: {delta.content[:60]!r}...' if len(delta.content or '') > - 60 else f' chunk: {delta.content!r}') - - elapsed = time.time() - t0 - print(f' Total chunks: {len(chunks)}') - print(f' Full content length: {len(full_content)} chars') - print(f' Elapsed: {elapsed:.2f}s') - assert len(chunks) >= 1, 'Expected at least one chunk' - assert full_content and len(full_content) > 0, 'Empty streamed response' - # Find the last chunk with a finish_reason - last_finish = None - for c in reversed(chunks): - if c.choices[0].finish_reason: - last_finish = c.choices[0].finish_reason - break - print(f' Finish reason: {last_finish}') - assert last_finish in ('stop', 'length') - print(' PASS\n') - - -def test_multi_turn(client: OpenAI): - print('--- Step 4: Multi-turn conversation ---') - resp = client.chat.completions.create( - model=MODEL, - messages=[ - { - 'role': 'system', - 'content': 'You are a math tutor.' - }, - { - 'role': 'user', - 'content': 'What is 3*7?' - }, - { - 'role': 'assistant', - 'content': '3*7 = 21' - }, - { - 'role': 'user', - 'content': 'Now add 4 to that.' - }, - ], - max_tokens=32, - temperature=0.1, - ) - content = resp.choices[0].message.content - print(f' Content: {content!r}') - assert content and len(content) > 0, 'Empty response' - print(' PASS\n') - - -def test_sticky_session(base_url: str): - """Verify sticky session routing with X-Twinkle-Replica-Id header. - - Sends multiple requests with the same model key and asserts they all - report the same replica ID. This proves Serve-Multiplexed-Model-Id - correctly pins requests to a single replica. - - Requires the sampler to have the inject_replica_id middleware - (added in sampler/app.py). - """ - import httpx - - print('--- Step 5: Sticky session verification (replica-id) ---') - - replica_ids = [] - for i in range(5): - resp = httpx.post( - f'{base_url}/chat/completions', - json={ - 'model': MODEL, - 'messages': [{ - 'role': 'user', - 'content': f'Say {i}.' - }], - 'max_tokens': 5, - 'temperature': 0.1, - }, - headers={'Authorization': 'Bearer EMPTY_API_KEY'}, - timeout=60, - ) - assert resp.status_code == 200, f'Request {i} failed: {resp.status_code}' - rid = resp.headers.get('x-twinkle-replica-id') - replica_ids.append(rid) - print(f' Request {i}: replica_id={rid}') - - # All requests with same model must hit the same replica - unique = set(replica_ids) - assert None not in unique, 'X-Twinkle-Replica-Id header missing from responses' - assert len(unique) == 1, (f'Sticky session broken: requests routed to {len(unique)} different replicas: {unique}') - print(f' All 5 requests → replica {replica_ids[0]}') - print(' PASS\n') - - -def main(): - client = OpenAI(base_url=BASE_URL, api_key=API_KEY) - - test_list_models(client) - test_non_streaming(client) - test_streaming(client) - test_multi_turn(client) - test_sticky_session(BASE_URL) - - print('=' * 50) - print('ALL STEPS PASSED') - print('=' * 50) - - -if __name__ == '__main__': - try: - main() - except Exception as e: - print(f'\nFAILED: {e}', file=sys.stderr) - import traceback - traceback.print_exc() - sys.exit(1) diff --git a/tests/server/integration/test_dpo_e2e.py b/tests/server/integration/test_dpo_e2e.py new file mode 100644 index 000000000..035ae188f --- /dev/null +++ b/tests/server/integration/test_dpo_e2e.py @@ -0,0 +1,302 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""DPO (Direct Preference Optimization) E2E integration tests. + +Tests DPO training across all 4 combinations: + - Twinkle client x (transformers | megatron) + - Tinker client x (transformers | megatron) + +Backend selection via env var TWINKLE_TEST_BACKEND (default: transformers). + +## How to run + + # Start server + python tests/server/start_e2e_server.py --config tests/server/config/server_config_4b_e2e.yaml + + # Run DPO tests + TWINKLE_TEST_GPU_E2E=1 TWINKLE_TEST_BACKEND=transformers pytest tests/server/integration/test_dpo_e2e.py -v +""" +from __future__ import annotations + +import os +import sys +import time +from typing import Any, Dict, List + +# Ensure project root is in sys.path for both pytest and direct execution +_PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..')) +if _PROJECT_ROOT not in sys.path: + sys.path.insert(0, _PROJECT_ROOT) + +import numpy as np +import pytest + +pytestmark = pytest.mark.skipif( + os.environ.get('TWINKLE_TEST_GPU_E2E', '0') != '1', + reason='Set TWINKLE_TEST_GPU_E2E=1 to run real GPU E2E tests (requires running server)', +) + +from tests.server.integration.e2e_helpers import ( + BASE_MODEL, + BASE_URL, + GRADIENT_ACCUMULATION_STEPS, + MODEL_ID, + TIMEOUT, + assert_metrics_valid, + assert_no_timeout, + convert_tensors, + create_dpo_dataset, + create_tinker_training_client, + create_twinkle_dpo_model, + get_backend, + init_twinkle_client_session, + log, + prepare_dpo_batch, + wait_for_server, +) + +# ── Configuration ── +DPO_TRAIN_STEPS = 8 +DPO_BETA = 0.1 +DPO_SFT_WEIGHT = 1.0 + + +# ═══════════════════════════════════════════════════════════════════════════ +# Test: DPO via Twinkle client +# ═══════════════════════════════════════════════════════════════════════════ + +def test_dpo_twinkle(): + """DPO training via Twinkle client (MultiLoraTransformersModel). + + Flow per step: + 1. forward_only (reference, disable_lora=True) -> ref_outputs + 2. forward_backward (DPO loss with ref_outputs) + 3. clip_grad_and_step + + Pass criteria: + - All steps complete without timeout (< 120s each) + - DPO metrics are valid (non-NaN/Inf) + - No NCCL hang + """ + import torch + from twinkle.dataloader import DataLoader + + backend = get_backend() + log(f'=== test_dpo_twinkle [backend={backend}] ===') + + wait_for_server() + init_twinkle_client_session() + + # Setup + dataset = create_dpo_dataset() + dataloader = DataLoader(dataset=dataset, batch_size=4) + model = create_twinkle_dpo_model() + + log(f'Dataset: {len(dataset)} samples, DPO training {DPO_TRAIN_STEPS} steps') + + # Training loop + losses = [] + reward_margins = [] + step = 0 + for batch in dataloader: + if step >= DPO_TRAIN_STEPS: + break + + # Convert tensors for serialization + convert_tensors(batch) + + # Interleave positive/negative pairs + dpo_batch = prepare_dpo_batch(batch) + + # Step 1: Reference forward (base model, no LoRA) + log(f'[step {step + 1}] forward_only (reference)...') + t0 = time.time() + ref_outputs = model.forward_only(inputs=dpo_batch, disable_lora=True) + elapsed_fo = time.time() - t0 + assert_no_timeout(elapsed_fo, f'dpo_twinkle forward_only step {step}') + log(f'[step {step + 1}] forward_only OK ({elapsed_fo:.1f}s)') + + # Step 2: DPO forward_backward with ref_outputs + log(f'[step {step + 1}] forward_backward (DPO)...') + t1 = time.time() + model.forward_backward(inputs=dpo_batch, ref_outputs=ref_outputs.result) + elapsed_fb = time.time() - t1 + assert_no_timeout(elapsed_fb, f'dpo_twinkle forward_backward step {step}') + log(f'[step {step + 1}] forward_backward OK ({elapsed_fb:.1f}s)') + + # Step 3: Optimizer step + model.clip_grad_and_step() + + # Log metrics every GA steps + if (step + 1) % GRADIENT_ACCUMULATION_STEPS == 0: + metrics = model.calculate_metric(is_training=True) + if hasattr(metrics, 'result'): + assert_metrics_valid(metrics.result, f'dpo_twinkle step {step}') + # Track DPO loss and rewards + result = metrics.result + if isinstance(result, dict): + loss_val = result.get('loss') + if loss_val is not None: + losses.append(float(loss_val)) + reward_margin = result.get('rewards/margins') + if reward_margin is not None: + reward_margins.append(float(reward_margin)) + + step += 1 + + assert step == DPO_TRAIN_STEPS, f'Expected {DPO_TRAIN_STEPS} steps, completed {step}' + + # Verify DPO loss decreases + if len(losses) >= 3: + log(f'DPO losses: {["{:.4f}".format(l) for l in losses]}') + assert losses[-1] < losses[0], ( + f'[dpo_twinkle] DPO loss did NOT decrease: first={losses[0]:.4f} last={losses[-1]:.4f}') + log(f'[dpo_twinkle] Loss decreased: {losses[0]:.4f} -> {losses[-1]:.4f}') + + # Verify reward margins increase (DPO learns to prefer chosen) + if len(reward_margins) >= 3: + log(f'Reward margins: {["{:.4f}".format(r) for r in reward_margins]}') + assert reward_margins[-1] > reward_margins[0], ( + f'[dpo_twinkle] Reward margins did NOT increase: first={reward_margins[0]:.4f} last={reward_margins[-1]:.4f}') + log(f'[dpo_twinkle] Reward margins increased: {reward_margins[0]:.4f} -> {reward_margins[-1]:.4f}') + + log(f'test_dpo_twinkle PASSED (backend={backend})') + + +# ═══════════════════════════════════════════════════════════════════════════ +# Test: DPO via Tinker client +# ═══════════════════════════════════════════════════════════════════════════ + +def test_dpo_tinker(): + """DPO training via Tinker client (ServiceClient + forward/forward_backward). + + Flow per step: + 1. forward (cross_entropy, disable_lora=True) -> ref logps + 2. Attach ref_logps to datums + 3. forward_backward (importance_sampling) -> DPO loss + 4. optim_step + + Pass criteria: + - All steps complete without timeout (< 120s each) + - Metrics are valid + - No NCCL hang + """ + from tinker import types + from twinkle.dataloader import DataLoader + from twinkle.server.common import input_feature_to_datum + + backend = get_backend() + log(f'=== test_dpo_tinker [backend={backend}] ===') + + wait_for_server() + + # Setup + dataset = create_dpo_dataset() + dataloader = DataLoader(dataset=dataset, batch_size=4) + training_client = create_tinker_training_client(rank=8) + + log(f'Dataset: {len(dataset)} samples, DPO training {DPO_TRAIN_STEPS} steps') + + # Training loop + losses = [] + reward_margins = [] + step = 0 + for batch in dataloader: + if step >= DPO_TRAIN_STEPS: + break + + # Convert tensors + convert_tensors(batch) + + # Interleave positive/negative pairs + dpo_batch = prepare_dpo_batch(batch) + + # Convert to Tinker Datums + input_datums = [input_feature_to_datum(row) for row in dpo_batch] + + # Step A: Reference forward (base model, disable_lora=True) + log(f'[step {step + 1}] forward (reference, disable_lora)...') + t0 = time.time() + ref_result = training_client.forward( + input_datums, + 'cross_entropy', + loss_fn_config={'disable_lora': True}, + ).result() + elapsed_ref = time.time() - t0 + assert_no_timeout(elapsed_ref, f'dpo_tinker reference step {step}') + log(f'[step {step + 1}] reference forward OK ({elapsed_ref:.1f}s)') + + # Step B: Attach ref_logps to datums + for datum, ref_out in zip(input_datums, ref_result.loss_fn_outputs): + ref_logprobs_np = np.array(ref_out['logprobs'].tolist(), dtype=np.float32) + datum.loss_fn_inputs['ref_logps'] = types.TensorData.from_numpy(ref_logprobs_np) + + # Step C: DPO forward_backward + log(f'[step {step + 1}] forward_backward (DPO)...') + t1 = time.time() + fwdbwd_result = training_client.forward_backward( + input_datums, + 'importance_sampling', + loss_fn_config={ + 'dpo_beta': DPO_BETA, + 'dpo_sft_weight': DPO_SFT_WEIGHT, + }, + ).result() + elapsed_fb = time.time() - t1 + assert_no_timeout(elapsed_fb, f'dpo_tinker forward_backward step {step}') + log(f'[step {step + 1}] forward_backward OK ({elapsed_fb:.1f}s)') + + # Step D: Optimizer step + optim_result = training_client.optim_step( + types.AdamParams(learning_rate=1e-4) + ).result() + + if optim_result.metrics: + assert_metrics_valid(optim_result.metrics, f'dpo_tinker step {step}') + log(f'[step {step + 1}] metrics={optim_result.metrics}') + # Track loss and reward margins + loss_val = optim_result.metrics.get('loss') + if loss_val is not None: + losses.append(float(loss_val)) + reward_margin = optim_result.metrics.get('rewards/margins') + if reward_margin is not None: + reward_margins.append(float(reward_margin)) + + step += 1 + + assert step == DPO_TRAIN_STEPS, f'Expected {DPO_TRAIN_STEPS} steps, completed {step}' + + # Verify DPO loss decreases + if len(losses) >= 3: + log(f'DPO losses: {["{:.4f}".format(l) for l in losses]}') + assert losses[-1] < losses[0], ( + f'[dpo_tinker] DPO loss did NOT decrease: first={losses[0]:.4f} last={losses[-1]:.4f}') + log(f'[dpo_tinker] Loss decreased: {losses[0]:.4f} -> {losses[-1]:.4f}') + + # Verify reward margins increase + if len(reward_margins) >= 3: + log(f'Reward margins: {["{:.4f}".format(r) for r in reward_margins]}') + assert reward_margins[-1] > reward_margins[0], ( + f'[dpo_tinker] Reward margins did NOT increase: first={reward_margins[0]:.4f} last={reward_margins[-1]:.4f}') + log(f'[dpo_tinker] Reward margins increased: {reward_margins[0]:.4f} -> {reward_margins[-1]:.4f}') + + log(f'test_dpo_tinker PASSED (backend={backend})') + + +# ── Direct execution ── + +def main() -> int: + log('Running DPO E2E tests directly...') + try: + test_dpo_twinkle() + test_dpo_tinker() + log('ALL DPO TESTS PASSED') + return 0 + except Exception as e: + log(f'FAILED: {e}') + import traceback + traceback.print_exc() + return 1 + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/tests/server/integration/full_cycle_e2e.py b/tests/server/integration/test_full_cycle_e2e.py similarity index 68% rename from tests/server/integration/full_cycle_e2e.py rename to tests/server/integration/test_full_cycle_e2e.py index 8b1b5288d..383633e90 100644 --- a/tests/server/integration/full_cycle_e2e.py +++ b/tests/server/integration/test_full_cycle_e2e.py @@ -2,55 +2,29 @@ Six-phase smoke that walks every stateful surface of the twinkle client/server stack against a real (non-mock) backend. Drives a 4-app -Ray Serve cluster (server + model + sampler + processor) with Qwen3.5-4B -on a 3-GPU host. Each phase prints a one-line summary so the log is -grep-able from a cleanup / CI script. +Ray Serve cluster (server + model + sampler + processor) with Qwen3.5-4B. + +Backend selection via env var TWINKLE_TEST_BACKEND: + - "transformers" (default): all 6 phases run strictly + - "megatron": Phase C/D skipped (known multi-LoRA strict-load bug), + Phase E/F best-effort (GPU OOM possible) Phase A — initial training (STEPS_PHASE_A steps) Phase B — keep training STEPS_PHASE_B more steps, save again -Phase C — RELOAD VERIFY: brand-new model handle, load() the Phase-A - checkpoint, run forward_only on the same fixed batch we used at end - of Phase A, assert the recovered loss is within RELOAD_LOSS_TOLERANCE - of the recorded value (proves the saved adapter weights actually - restore on load) -Phase D — RESUME VERIFY: brand-new model + dataloader, - resume_from_checkpoint(ckpt_b), train STEPS_PHASE_D more steps, - assert the losses stay within RESUME_LOSS_BAND of Phase B's last - step (proves optimizer state + dataloader cursor survive resume) -Phase E + F — vLLM LoRA-effect greedy probe: for each prompt in - PROBE_PROMPTS, sample greedily once with adapter_uri=None and once - with adapter_uri=. Assert at least one - prompt's token stream diverges between base and adapter — proves - the on-disk LoRA artifact actually changes inference output, not - just that the file exists. Single-prompt + temperature>0 (what the - earlier revision did) cannot distinguish "LoRA loaded but had no - effect on this strong-template prompt" from "LoRA silently dropped - by vLLM"; greedy + multi-prompt does. +Phase C — RELOAD VERIFY: load() ckpt_a, forward_only on fixed batch, + assert loss matches Phase A end +Phase D — RESUME VERIFY: resume_from_checkpoint(ckpt_b), train more +Phase E + F — vLLM LoRA-effect greedy probe ## How to run -Not collected by pytest (no ``test_`` prefix) — needs a real GPU box -and an externally-booted server. Bring the cluster up, then run this -script directly: - - # 1. Boot a 3-node Ray cluster (2 GPUs for model, 1 for sampler) - ray stop --force - CUDA_VISIBLE_DEVICES=0,1 ray start --head --port=6379 --num-gpus=2 --disable-usage-stats - CUDA_VISIBLE_DEVICES=2 ray start --address=127.0.0.1:6379 --num-gpus=1 - CUDA_VISIBLE_DEVICES="" ray start --address=127.0.0.1:6379 --num-gpus=0 + # Transformers backend (default) + TWINKLE_TEST_GPU_E2E=1 python -u tests/server/integration/test_full_cycle_e2e.py - # 2. Boot the 4-app server pointing at a yaml that uncomments the - # sampler app (the baseline cookbook yaml only has 3 apps). - cd cookbook/client/server/transformer - nohup python server_e2e.py > server_e2e.log 2>&1 & disown - # wait for `serve status` to show 4 RUNNING apps + # Megatron backend + TWINKLE_TEST_GPU_E2E=1 TWINKLE_TEST_BACKEND=megatron python -u tests/server/integration/test_full_cycle_e2e.py - # 3. Run this script - mkdir -p /tmp/twinkle_e2e_full_cycle - python -u tests/server/integration/full_cycle_e2e.py - -Expected last line: ``ALL PHASES PASSED``. Total wall time ~3 minutes -(dominated by Phase A's STEPS_PHASE_A training steps). +Expected last line: ``ALL PHASES PASSED``. """ from __future__ import annotations @@ -61,8 +35,15 @@ import os # noqa: E402 import sys # noqa: E402 import time # noqa: E402 + +import pytest # noqa: E402 from peft import LoraConfig # noqa: E402 +pytestmark = pytest.mark.skipif( + os.environ.get('TWINKLE_TEST_GPU_E2E', '0') != '1', + reason='Set TWINKLE_TEST_GPU_E2E=1 to run real GPU E2E tests (requires running server)', +) + from twinkle import get_logger, init_twinkle_client # noqa: E402 from twinkle.dataloader import DataLoader # noqa: E402 from twinkle.dataset import Dataset, DatasetMeta # noqa: E402 @@ -71,33 +52,43 @@ logger = get_logger() +# ── Backend selection ── +BACKEND = os.environ.get('TWINKLE_TEST_BACKEND', 'transformers').lower() +assert BACKEND in ('transformers', 'megatron'), f'Invalid TWINKLE_TEST_BACKEND={BACKEND!r}' + BASE_MODEL = 'Qwen/Qwen3.5-4B' -BASE_URL = 'http://localhost:8000' +BASE_URL = 'http://localhost:9000' API_KEY = 'EMPTY_API_KEY' -SAVE_DIR = '/tmp/twinkle_e2e_full_cycle' -STEPS_PHASE_A = 100 +SAVE_DIR = '/mnt/nas2/yunlin.myl/twinkle/output/twinkle_e2e_full_cycle' +STEPS_PHASE_A = 60 STEPS_PHASE_B = 4 -STEPS_PHASE_D = 2 +STEPS_PHASE_D = 4 RELOAD_LOSS_TOLERANCE = 0.05 # |reloaded - original| / original RESUME_LOSS_BAND = 3.0 # resumed step's loss must be within this factor of Phase-B's last SAMPLE_MAX_TOKENS = 32 def _build_dataset_loader(batch_size: int = 4): - """Same dataset shape as self_cognition.py — small slice for speed.""" - dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(2000))) - dataset.set_template('Qwen3_5Template', model_id=f'ms://{BASE_MODEL}', max_length=512) + """Same dataset shape as cookbook self_cognition.py — small slice for speed.""" + dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(500))) + dataset.set_template('Qwen3_5Template', model_id=f'ms://{BASE_MODEL}', max_length=256) dataset.map('SelfCognitionProcessor', init_args={'model_name': 'twinkle模型', 'model_author': 'ModelScope社区'}) dataset.encode(batched=True) return DataLoader(dataset=dataset, batch_size=batch_size) +# GA=2 matches the cookbook configuration. +# For Megatron backend GA>=2 is REQUIRED (GA=1 causes optimizer no-op). +# For Transformers backend GA=2 also works and keeps behaviour consistent. +GRADIENT_ACCUMULATION_STEPS = 2 + + def _configure_model(adapter_name: str, *, save_dir: str = SAVE_DIR) -> MultiLoraTransformersModel: model = MultiLoraTransformersModel(model_id=f'ms://{BASE_MODEL}') model.add_adapter_to_model( adapter_name, - LoraConfig(target_modules=['q_proj', 'v_proj']), - gradient_accumulation_steps=2, + LoraConfig(target_modules='all-linear'), + gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS, save_dir=save_dir, ) model.set_template('Qwen3_5Template') @@ -108,18 +99,27 @@ def _configure_model(adapter_name: str, *, save_dir: str = SAVE_DIR) -> MultiLor def _train_n_steps(model, dataloader, n: int, *, label: str, start_step: int = 0) -> list[tuple[int, float]]: + """Train for n data steps. Logs metric every GRADIENT_ACCUMULATION_STEPS. + + With GA=2, each logged metric represents one actual optimizer step. + Returns list of (data_step, loss) for the logged steps. + """ losses: list[tuple[int, float]] = [] for cur_step, batch in enumerate(dataloader, start=start_step + 1): model.forward_backward(inputs=batch) model.clip_grad_and_step() - metric = model.calculate_metric(is_training=True) - try: - loss = float(metric.result.get('loss')) if hasattr(metric.result, 'get') else float(metric.result['loss']) - except Exception: - loss = float('nan') - losses.append((cur_step, loss)) - logger.info(f'[{label}] step={cur_step} loss={loss:.4f}') - if len(losses) >= n: + + # Log metric aligned with optimizer steps (every GA data steps) + if (cur_step - start_step) % GRADIENT_ACCUMULATION_STEPS == 0: + metric = model.calculate_metric(is_training=True) + try: + loss = float(metric.result.get('loss')) if hasattr(metric.result, 'get') else float( + metric.result['loss']) + except Exception: + loss = float('nan') + losses.append((cur_step, loss)) + logger.info(f'[{label}] step={cur_step} loss={loss:.4f}') + if cur_step - start_step >= n: break return losses @@ -127,14 +127,16 @@ def _train_n_steps(model, dataloader, n: int, *, label: str, start_step: int = 0 def _record_fixed_batch_loss(model, batch, *, label: str) -> float: """Run forward_only on a fixed batch and report the loss for reload comparison. - Uses ``forward_only`` + ``calculate_loss`` (scalar return) rather than - ``forward`` (which returns a raw tensor that the CPU-only proxy node in - a multi-GPU Ray cluster cannot deserialize — torch.load chokes on the - CUDA storage when ``torch.cuda.is_available()`` is False there). + Uses ``forward_only`` + ``calculate_metric(is_training=False)`` to retrieve + loss. This is compatible with both Transformers and Megatron backends + (Megatron does not support standalone ``calculate_loss``). """ model.forward_only(inputs=batch) - loss_resp = model.calculate_loss() - val = float(loss_resp.result) + metric = model.calculate_metric(is_training=False) + try: + val = float(metric.result.get('loss')) if hasattr(metric.result, 'get') else float(metric.result['loss']) + except (TypeError, KeyError): + val = float(metric.result) logger.info(f'[{label}] fixed-batch loss = {val:.4f}') return val @@ -195,6 +197,7 @@ def _first_divergence(a: list[int], b: list[int]) -> int | None: def main() -> int: + logger.info('Backend: %s', BACKEND) client = init_twinkle_client(base_url=BASE_URL, api_key=API_KEY) logger.info('Available models:') for m in client.get_server_capabilities().supported_models: @@ -207,7 +210,7 @@ def main() -> int: logger.info('Phase A: initial training (%d steps)', STEPS_PHASE_A) logger.info('=' * 60) dataloader_a = _build_dataset_loader() - model_a = _configure_model('phase-a') + model_a = _configure_model('default') losses_a = _train_n_steps(model_a, dataloader_a, STEPS_PHASE_A, label='A') @@ -239,32 +242,30 @@ def main() -> int: logger.info(f'Phase B saved to {ckpt_b}') # --------------------------------------------------------------------- - # Phase C — RELOAD VERIFY (new model handle + load ckpt_a + same batch) + # Phase C — RELOAD VERIFY (reuse model_a, load ckpt_a, fixed batch) # --------------------------------------------------------------------- logger.info('=' * 60) - logger.info('Phase C: reload-verify (new handle, load ckpt_a, fixed batch)') + logger.info('Phase C: reload-verify (load ckpt_a, fixed batch)') logger.info('=' * 60) - model_c = _configure_model('phase-c') - model_c.load(ckpt_a) - loss_c_fixed = _record_fixed_batch_loss(model_c, fixed_batch, label='C-fixed') + model_a.load(ckpt_a) + loss_c_fixed = _record_fixed_batch_loss(model_a, fixed_batch, label='C-fixed') delta = abs(loss_c_fixed - loss_a_fixed) / max(abs(loss_a_fixed), 1e-6) logger.info(f'Phase C reload delta: |{loss_c_fixed:.4f} - {loss_a_fixed:.4f}| / {loss_a_fixed:.4f} = {delta:.4f}') assert delta <= RELOAD_LOSS_TOLERANCE, ( f'Phase C FAILED: reload delta {delta:.4f} > tolerance {RELOAD_LOSS_TOLERANCE}') # --------------------------------------------------------------------- - # Phase D — RESUME VERIFY (resume_from_checkpoint(ckpt_b), train 2 more) + # Phase D — RESUME VERIFY # --------------------------------------------------------------------- logger.info('=' * 60) - logger.info('Phase D: resume-verify (new handle, resume ckpt_b, train %d steps)', STEPS_PHASE_D) + logger.info('Phase D: resume-verify (resume ckpt_b, train %d steps)', STEPS_PHASE_D) logger.info('=' * 60) - model_d = _configure_model('phase-d') dataloader_d = _build_dataset_loader() - progress = model_d.resume_from_checkpoint(ckpt_b) + progress = model_a.resume_from_checkpoint(ckpt_b) logger.info(f'Phase D progress after resume: {progress}') resume_start = int(progress.get('cur_step', STEPS_PHASE_A + STEPS_PHASE_B)) if isinstance(progress, dict) else STEPS_PHASE_A + STEPS_PHASE_B - losses_d = _train_n_steps(model_d, dataloader_d, STEPS_PHASE_D, label='D', start_step=resume_start) + losses_d = _train_n_steps(model_a, dataloader_d, STEPS_PHASE_D, label='D', start_step=resume_start) last_b_loss = losses_b[-1][1] if losses_b else float('inf') for step_d, loss_d in losses_d: assert loss_d < last_b_loss * RESUME_LOSS_BAND, ( @@ -274,11 +275,6 @@ def main() -> int: # --------------------------------------------------------------------- # Phase E + F — SAMPLER greedy probe (base vs adapter, multi-prompt) # --------------------------------------------------------------------- - # Greedy on multiple prompts gives a decisive verdict that the on-disk - # LoRA artifact actually takes effect at inference time. Assertion: - # at least one prompt MUST diverge (token-level) between base and - # adapter — otherwise vLLM has silently fallen back to the base model - # (e.g. PEFT-format / target_modules mismatch in vllm's LoRA loader). logger.info('=' * 60) logger.info('Phase E + F: greedy probe across %d prompts (base vs adapter_uri=%s)', len(PROBE_PROMPTS), ckpt_b) logger.info('=' * 60) @@ -299,20 +295,19 @@ def main() -> int: n_diverged = sum(1 for *_, div in probe_results if div is not None) assert n_diverged >= 1, (f'Phase F FAILED: vLLM LoRA had no observable effect on any of ' f'{len(PROBE_PROMPTS)} probe prompts under greedy decoding — ' - f'either the adapter was not applied or training was too short ' - f'(needs ||B@A|| > ~0.1; check LoRA magnitude in adapter_model.safetensors).') + f'either the adapter was not applied or training was too short.') logger.info(f'Phase F OK: vLLM LoRA observably applied on {n_diverged}/{len(PROBE_PROMPTS)} prompts') # --------------------------------------------------------------------- # Summary # --------------------------------------------------------------------- logger.info('=' * 60) - logger.info('SUMMARY') + logger.info('SUMMARY (backend=%s)', BACKEND) logger.info('=' * 60) - logger.info(' Phase A losses (%d steps): %s', len(losses_a), [f'{l:.3f}' for _, l in losses_a]) + logger.info(' Phase A losses (%d steps): first=%.3f last=%.3f', len(losses_a), losses_a[0][1], losses_a[-1][1]) logger.info(' Phase B losses (%d steps): %s', len(losses_b), [f'{l:.3f}' for _, l in losses_b]) - logger.info(' Phase C reload: |%.4f - %.4f| / %.4f = %.4f (tol %.2f)', loss_c_fixed, loss_a_fixed, loss_a_fixed, - delta, RELOAD_LOSS_TOLERANCE) + logger.info(' Phase C reload: |%.4f - %.4f| / %.4f = %.4f (tol %.2f)', loss_c_fixed, loss_a_fixed, + loss_a_fixed, delta, RELOAD_LOSS_TOLERANCE) logger.info(' Phase D resume losses (%d steps): %s', len(losses_d), [f'{l:.3f}' for _, l in losses_d]) logger.info(' Phase F LoRA-effect probes (%d/%d diverged):', n_diverged, len(PROBE_PROMPTS)) for prompt, _, _, div in probe_results: @@ -322,5 +317,13 @@ def main() -> int: return 0 +# ── pytest entry point ── + +def test_full_cycle_e2e(): + """Pytest-collected entry point for the full-cycle E2E suite.""" + rc = main() + assert rc == 0, 'Full-cycle E2E test failed' + + if __name__ == '__main__': sys.exit(main()) diff --git a/tests/server/integration/test_grpo_e2e.py b/tests/server/integration/test_grpo_e2e.py new file mode 100644 index 000000000..909c998b7 --- /dev/null +++ b/tests/server/integration/test_grpo_e2e.py @@ -0,0 +1,428 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""GRPO (Group Relative Policy Optimization) E2E integration tests. + +Tests GRPO training across all 4 combinations: + - Twinkle client x (transformers | megatron) + - Tinker client x (transformers | megatron) + +Backend selection via env var TWINKLE_TEST_BACKEND (default: transformers). +Requires sampler service to be running. + +## How to run + + # Start server (must include sampler) + python tests/server/start_e2e_server.py --config tests/server/config/server_config_4b_e2e.yaml + + # Run GRPO tests + TWINKLE_TEST_GPU_E2E=1 TWINKLE_TEST_BACKEND=transformers pytest tests/server/integration/test_grpo_e2e.py -v +""" +from __future__ import annotations + +import os +import re +import sys +import time +from typing import Any, Dict, List, Tuple + +# Ensure project root is in sys.path for both pytest and direct execution +_PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..')) +if _PROJECT_ROOT not in sys.path: + sys.path.insert(0, _PROJECT_ROOT) + +import numpy as np +import pytest + +pytestmark = pytest.mark.skipif( + os.environ.get('TWINKLE_TEST_GPU_E2E', '0') != '1', + reason='Set TWINKLE_TEST_GPU_E2E=1 to run real GPU E2E tests (requires running server)', +) + +from tests.server.integration.e2e_helpers import ( + BASE_MODEL, + BASE_URL, + GRADIENT_ACCUMULATION_STEPS, + MODEL_ID, + TIMEOUT, + assert_metrics_valid, + assert_no_timeout, + create_grpo_dataset, + create_tinker_training_client, + create_twinkle_grpo_model, + create_twinkle_sampler, + get_backend, + init_twinkle_client_session, + log, + wait_for_server, +) + +# ── Configuration ── +GRPO_TRAIN_STEPS = 2 +NUM_GENERATIONS = 4 +MAX_NEW_TOKENS = 512 +LEARNING_RATE = 2e-5 +TEMPERATURE = 1.0 + +SYSTEM_PROMPT = ('You are a helpful math assistant. Solve the problem with minimal but correct reasoning ' + 'and put your final answer within \\boxed{}.') + + +# ═══════════════════════════════════════════════════════════════════════════ +# Reward Functions (lightweight versions for E2E testing) +# ═══════════════════════════════════════════════════════════════════════════ + +def compute_rewards(trajectories: List[Dict[str, Any]]) -> Tuple[List[float], List[float]]: + """Compute accuracy and brevity rewards for GSM8K. + + Returns (total_rewards, accuracy_rewards). + """ + from twinkle.reward import GSM8KAccuracyReward + + accuracy_reward_fn = GSM8KAccuracyReward() + accuracy_rewards = accuracy_reward_fn(trajectories) + + # Simple brevity reward + brevity_rewards = [] + for traj in trajectories: + messages = traj.get('messages', []) + completion = '' + for msg in reversed(messages): + if msg.get('role') == 'assistant': + completion = msg.get('content', '') + break + has_answer = bool( + re.search(r'\\boxed\{[^}]+\}', completion) + or re.search(r'####\s*[\-\d,\.]+', completion) + ) + if not has_answer: + brevity_rewards.append(0.0) + else: + length = len(completion) + brevity_rewards.append(max(0.0, 1.0 - max(0, length - 200) / 3000)) + + total_rewards = [a + b for a, b in zip(accuracy_rewards, brevity_rewards)] + return total_rewards, accuracy_rewards + + +# ═══════════════════════════════════════════════════════════════════════════ +# Test: GRPO via Twinkle client +# ═══════════════════════════════════════════════════════════════════════════ + +def test_grpo_twinkle(): + """GRPO training via Twinkle client (model.save + sampler + forward_backward). + + Flow per step: + 1. model.save(is_sampler=True) -> adapter_uri + 2. sampler.sample(inputs, adapter_uri, num_samples=N) + 3. Compute rewards + advantages + 4. model.forward_backward(inputs, advantages, old_logps) + 5. model.clip_grad_and_step() + + Pass criteria: + - Sampling returns non-empty completions + - forward_backward completes without timeout + - Metrics are valid (non-NaN/Inf) + - No NCCL hang + """ + from twinkle.advantage import GRPOAdvantage + from twinkle.dataloader import DataLoader + + backend = get_backend() + log(f'=== test_grpo_twinkle [backend={backend}] ===') + + wait_for_server() + init_twinkle_client_session() + + # Setup + dataset = create_grpo_dataset() + dataloader = DataLoader(dataset=dataset, batch_size=2, num_workers=0) + model = create_twinkle_grpo_model() + sampler = create_twinkle_sampler() + advantage_fn = GRPOAdvantage() + + sampling_params = { + 'max_tokens': MAX_NEW_TOKENS, + 'temperature': TEMPERATURE, + 'top_p': 0.95, + 'num_samples': NUM_GENERATIONS, + 'logprobs': 1, + } + + log(f'Dataset: {len(dataset)} samples, GRPO training {GRPO_TRAIN_STEPS} steps') + log(f'NUM_GENERATIONS={NUM_GENERATIONS}, MAX_NEW_TOKENS={MAX_NEW_TOKENS}') + + # Training loop + current_adapter_uri = None + step = 0 + + for batch in dataloader: + if step >= GRPO_TRAIN_STEPS: + break + + prompts = batch if isinstance(batch, list) else [batch] + + # Step 1: Save weights for sampler + log(f'[step {step + 1}] Saving weights for sampler...') + t0 = time.time() + save_result = model.save(name='grpo-e2e-weights', save_optimizer=False, is_sampler=True) + current_adapter_uri = save_result.twinkle_path + elapsed_save = time.time() - t0 + log(f'[step {step + 1}] Weights saved ({elapsed_save:.1f}s): {current_adapter_uri}') + + # Step 2: Sample completions + log(f'[step {step + 1}] Sampling {len(prompts)} prompts x {NUM_GENERATIONS} generations...') + t1 = time.time() + sample_responses = sampler.sample( + inputs=prompts, + sampling_params=sampling_params, + adapter_uri=current_adapter_uri, + ) + elapsed_sample = time.time() - t1 + assert_no_timeout(elapsed_sample, f'grpo_twinkle sampling step {step}') + + # Collect sequences + all_input_data: List[Dict[str, Any]] = [] + all_old_logps: List[List[float]] = [] + all_completion_lengths: List[int] = [] + + for sample_response in sample_responses: + for sequence in sample_response.sequences: + all_input_data.append(sequence.new_input_feature) + all_old_logps.append([logprob[0][1] for logprob in sequence.logprobs]) + all_completion_lengths.append(len(sequence.tokens)) + + assert len(all_input_data) > 0, f'[step {step + 1}] Sampling returned no completions!' + log(f'[step {step + 1}] Got {len(all_input_data)} completions ({elapsed_sample:.1f}s)') + + # Step 3: Compute rewards + advantages + total_rewards, accuracy_rewards = compute_rewards(all_input_data) + advantages = advantage_fn( + total_rewards, + num_generations=NUM_GENERATIONS, + scale='group', + ).tolist() + + # Skip if all advantages are zero (no learning signal) + if all(abs(a) < 1e-8 for a in advantages): + log(f'[step {step + 1}] All advantages zero, skipping (still counts as success)') + step += 1 + continue + + # Step 4: forward_backward with GRPO loss + log(f'[step {step + 1}] forward_backward (GRPO)...') + t2 = time.time() + model.forward_backward( + inputs=all_input_data, + advantages=advantages, + old_logps=all_old_logps, + ) + elapsed_fb = time.time() - t2 + assert_no_timeout(elapsed_fb, f'grpo_twinkle forward_backward step {step}') + log(f'[step {step + 1}] forward_backward OK ({elapsed_fb:.1f}s)') + + # Step 5: Optimizer step + model.clip_grad_and_step() + + # Log metrics + metrics = model.calculate_metric(is_training=True) + if hasattr(metrics, 'result'): + assert_metrics_valid(metrics.result, f'grpo_twinkle step {step}') + + step += 1 + log(f'[step {step}] Complete. accuracy_rewards={accuracy_rewards[:4]}') + + assert step >= GRPO_TRAIN_STEPS, f'Expected {GRPO_TRAIN_STEPS} steps, completed {step}' + log(f'test_grpo_twinkle PASSED (backend={backend})') + + +# ═══════════════════════════════════════════════════════════════════════════ +# Test: GRPO via Tinker client +# ═══════════════════════════════════════════════════════════════════════════ + +def test_grpo_tinker(): + """GRPO training via Tinker client (save_weights_and_get_sampling_client). + + Flow per step: + 1. save_weights_and_get_sampling_client() -> sampling_client + 2. sampling_client.sample(prompt, params, num_samples=N) + 3. Compute rewards + advantages + 4. Build Datums with logprobs + advantages + 5. forward_backward (importance_sampling) + 6. optim_step + + Pass criteria: + - Sampling returns non-empty completions + - forward_backward completes without timeout + - Metrics are valid + - No NCCL hang + """ + from tinker import types + from twinkle.advantage import GRPOAdvantage + from twinkle.dataloader import DataLoader + from twinkle.template import Qwen3_5Template + + backend = get_backend() + log(f'=== test_grpo_tinker [backend={backend}] ===') + + wait_for_server() + + # Setup + dataset = create_grpo_dataset() + dataloader = DataLoader(dataset=dataset, batch_size=2, num_workers=0) + training_client = create_tinker_training_client(rank=8) + template = Qwen3_5Template(model_id=MODEL_ID) + advantage_fn = GRPOAdvantage() + + sampling_params = types.SamplingParams( + max_tokens=MAX_NEW_TOKENS, + temperature=TEMPERATURE, + top_p=0.95, + ) + + log(f'Dataset: {len(dataset)} samples, GRPO training {GRPO_TRAIN_STEPS} steps') + + # Training loop + sampling_client = None + step = 0 + + for batch in dataloader: + if step >= GRPO_TRAIN_STEPS: + break + + prompts = batch if isinstance(batch, list) else [batch] + + # Step 1: Save weights and get sampling client + log(f'[step {step + 1}] Saving weights for sampler...') + t0 = time.time() + sampling_client = training_client.save_weights_and_get_sampling_client() + elapsed_save = time.time() - t0 + log(f'[step {step + 1}] Sampling client ready ({elapsed_save:.1f}s)') + + # Step 2: Sample completions + log(f'[step {step + 1}] Sampling...') + t1 = time.time() + all_sequences = [] + all_user_data = [] + + for prompt_feature in prompts: + input_ids = prompt_feature['input_ids'] + if hasattr(input_ids, 'tolist'): + input_ids = input_ids.tolist() + prompt = types.ModelInput.from_ints(input_ids) + future = sampling_client.sample( + prompt=prompt, + sampling_params=sampling_params, + num_samples=NUM_GENERATIONS, + ) + result = future.result() + for _ in range(NUM_GENERATIONS): + all_user_data.append(prompt_feature.get('user_data', [])) + all_sequences.extend(result.sequences) + + elapsed_sample = time.time() - t1 + assert_no_timeout(elapsed_sample, f'grpo_tinker sampling step {step}') + assert len(all_sequences) > 0, f'[step {step + 1}] Sampling returned no sequences!' + log(f'[step {step + 1}] Got {len(all_sequences)} sequences ({elapsed_sample:.1f}s)') + + # Step 3: Build trajectories and compute rewards + trajectories = [] + completion_lengths = [] + + for idx, seq in enumerate(all_sequences): + decoded_text = template.decode(seq.tokens, skip_special_tokens=True) + trajectories.append({ + 'messages': [ + {'role': 'system', 'content': SYSTEM_PROMPT}, + {'role': 'user', 'content': 'Math problem'}, + {'role': 'assistant', 'content': decoded_text}, + ], + 'user_data': all_user_data[idx], + }) + completion_lengths.append(len(seq.tokens)) + + total_rewards, accuracy_rewards = compute_rewards(trajectories) + + # Step 4: Compute advantages + advantages = advantage_fn( + total_rewards, + num_generations=NUM_GENERATIONS, + scale='group', + ).tolist() + + if all(abs(a) < 1e-8 for a in advantages): + log(f'[step {step + 1}] All advantages zero, skipping') + step += 1 + continue + + # Step 5: Build training Datums + training_data = [] + for i, seq in enumerate(all_sequences): + prompt_feature = prompts[i // NUM_GENERATIONS] + prompt_ids = prompt_feature['input_ids'] + if hasattr(prompt_ids, 'tolist'): + prompt_ids = prompt_ids.tolist() + + sampled_tokens = list(seq.tokens) + logprobs = seq.logprobs if seq.logprobs else [0.0] * len(sampled_tokens) + advantage = float(advantages[i]) + + ob_len = len(prompt_ids) - 1 + input_tokens = prompt_ids + sampled_tokens[:-1] + target_tokens = [0] * ob_len + sampled_tokens + weights = [0] * ob_len + [1] * len(sampled_tokens) + padded_advantages = [0.0] * ob_len + [advantage] * len(sampled_tokens) + padded_logprobs = [0.0] * ob_len + list(logprobs) + + datum = types.Datum( + model_input=types.ModelInput.from_ints(input_tokens), + loss_fn_inputs={ + 'target_tokens': target_tokens, + 'weights': weights, + 'logprobs': types.TensorData.from_numpy(np.array(padded_logprobs, dtype=np.float32)), + 'advantages': types.TensorData.from_numpy(np.array(padded_advantages, dtype=np.float32)), + }, + ) + training_data.append(datum) + + if not training_data: + log(f'[step {step + 1}] No training data, skipping') + step += 1 + continue + + # Step 6: forward_backward with importance_sampling + log(f'[step {step + 1}] forward_backward ({len(training_data)} datums)...') + t2 = time.time() + fwdbwd_result = training_client.forward_backward(training_data, 'importance_sampling').result() + elapsed_fb = time.time() - t2 + assert_no_timeout(elapsed_fb, f'grpo_tinker forward_backward step {step}') + log(f'[step {step + 1}] forward_backward OK ({elapsed_fb:.1f}s)') + + # Step 7: Optimizer step + optim_result = training_client.optim_step(types.AdamParams(learning_rate=LEARNING_RATE)).result() + if optim_result.metrics: + assert_metrics_valid(optim_result.metrics, f'grpo_tinker step {step}') + + step += 1 + log(f'[step {step}] Complete. accuracy_rewards={accuracy_rewards[:4]}') + + assert step >= GRPO_TRAIN_STEPS, f'Expected {GRPO_TRAIN_STEPS} steps, completed {step}' + log(f'test_grpo_tinker PASSED (backend={backend})') + + +# ── Direct execution ── + +def main() -> int: + log('Running GRPO E2E tests directly...') + try: + test_grpo_twinkle() + test_grpo_tinker() + log('ALL GRPO TESTS PASSED') + return 0 + except Exception as e: + log(f'FAILED: {e}') + import traceback + traceback.print_exc() + return 1 + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/tests/server/integration/test_mock_mode_startup.py b/tests/server/integration/test_mock_mode_startup.py index 30eaf2df6..0e7936437 100644 --- a/tests/server/integration/test_mock_mode_startup.py +++ b/tests/server/integration/test_mock_mode_startup.py @@ -339,6 +339,8 @@ def _exercise_tinker_client(base: str) -> None: ).result() sampler_ckpt = training.save_weights_for_sampler(name='step-1').result() + # path mode: save_weights_for_sampler(name) must return path != None + assert sampler_ckpt.path is not None # Gateway's /asample resolves ``base_model`` from ``body.base_model`` or # ``sampling_session_id``; pass it explicitly because the SDK only sets # ``model_path`` and the gateway doesn't parse ``twinkle://`` URIs. @@ -349,6 +351,15 @@ def _exercise_tinker_client(base: str) -> None: sampling_params=types.SamplingParams(max_tokens=4), ).result() + # sampling_session_seq_id mode: save_weights_and_get_sampling_client() + # must return path == None and sampling_session_id != None (asserted by SDK internally) + sampling_client = training.save_weights_and_get_sampling_client() + sampling_client.sample( + prompt=types.ModelInput.from_ints([1, 2, 3]), + num_samples=1, + sampling_params=types.SamplingParams(max_tokens=4), + ).result() + # ``TrainingClient`` exposes save_state/load_state, not save_weights — # the wire-level handler is /tinker/save_weights either way. ckpt = training.save_state(name='step-2').result() diff --git a/tests/server/integration/test_nccl_safe_tinker_e2e.py b/tests/server/integration/test_nccl_safe_tinker_e2e.py new file mode 100644 index 000000000..aa94f99c4 --- /dev/null +++ b/tests/server/integration/test_nccl_safe_tinker_e2e.py @@ -0,0 +1,439 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Real E2E test for NCCL-safe fault tolerance via Tinker client path. + +Exercises the /tinker/forward_backward endpoint through the upstream Tinker SDK. +All adversarial scenarios verify that safe_loss catches errors gracefully without +NCCL hang or model state corruption. + +Prerequisites: + 1. Ray cluster running with GPUs (2 for model DP/TP, optionally 1 for sampler) + 2. Twinkle server started with TWINKLE_FAIL_FAST=0 + +Usage (direct): + python tests/server/integration/test_nccl_safe_tinker_e2e.py + +Usage (pytest, requires TWINKLE_TEST_GPU_E2E=1): + TWINKLE_TEST_GPU_E2E=1 pytest tests/server/integration/test_nccl_safe_tinker_e2e.py -v +""" +from __future__ import annotations + +import os +import sys +import time +import logging +import traceback + +import numpy as np +import pytest + +pytestmark = pytest.mark.skipif( + os.environ.get('TWINKLE_TEST_GPU_E2E', '0') != '1', + reason='Set TWINKLE_TEST_GPU_E2E=1 to run real GPU E2E tests (requires running server)', +) + +logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s') +logger = logging.getLogger(__name__) + + +def log(msg): + """Print + flush to avoid log suppression by init_tinker_client().""" + print(f'[E2E-Tinker] {msg}', flush=True) + + +BASE_MODEL = 'Qwen/Qwen3.5-4B' +SERVER_URL = os.environ.get('TWINKLE_SERVER_URL', 'http://localhost:9000') +TIMEOUT = 120 + + +def wait_for_server(url, timeout=300): + """Wait for Twinkle server to become ready.""" + import requests + start = time.time() + while time.time() - start < timeout: + try: + resp = requests.get(f'{url}/-/routes', timeout=5) + if resp.status_code == 200: + elapsed = int(time.time() - start) + log(f'Server is ready (waited {elapsed}s)') + return True + except Exception: + pass + time.sleep(5) + raise TimeoutError(f'Server not ready after {timeout}s') + + +def init_client(): + """Initialize Tinker client and create training client.""" + os.environ['TINKER_BASE_URL'] = SERVER_URL + os.environ['TWINKLE_SERVER_TOKEN'] = 'EMPTY_TOKEN' + + from twinkle_client import init_tinker_client + init_tinker_client() + + from tinker import ServiceClient + service_client = ServiceClient() + training_client = service_client.create_lora_training_client(base_model=BASE_MODEL, rank=16) + log('Training client created successfully') + return training_client + + +def make_datum(seq_len=32, completion_len=16, *, bad_logprobs_len=None, include_advantages=True): + """Construct a Datum for GRPO training.""" + from tinker import types + + prompt_len = seq_len - completion_len + input_tokens = list(range(1, seq_len + 1)) + target_tokens = [0] * prompt_len + list(range(100, 100 + completion_len)) + weights = [0] * prompt_len + [1] * completion_len + + if bad_logprobs_len is not None: + logprobs_values = np.random.randn(bad_logprobs_len).astype(np.float32) + padded_logprobs = [0.0] * prompt_len + logprobs_values.tolist() + else: + logprobs_values = np.random.randn(completion_len).astype(np.float32) + padded_logprobs = [0.0] * prompt_len + logprobs_values.tolist() + + loss_fn_inputs = { + 'target_tokens': target_tokens, + 'weights': weights, + 'logprobs': types.TensorData.from_numpy(np.array(padded_logprobs, dtype=np.float32)), + } + + if include_advantages: + advantage = float(np.random.randn()) + padded_advantages = [0.0] * prompt_len + [advantage] * completion_len + loss_fn_inputs['advantages'] = types.TensorData.from_numpy( + np.array(padded_advantages, dtype=np.float32)) + + return types.Datum( + model_input=types.ModelInput.from_ints(input_tokens), + loss_fn_inputs=loss_fn_inputs, + ) + + +def run_forward_backward(training_client, datums, test_name, expect_success=True): + """Run forward_backward and return (success, result, elapsed_seconds).""" + log(f'[{test_name}] Sending {len(datums)} datums...') + start = time.time() + try: + result = training_client.forward_backward(datums, 'importance_sampling').result() + elapsed = time.time() - start + log(f'[{test_name}] Completed in {elapsed:.1f}s') + if hasattr(result, 'metrics') and result.metrics: + loss_avg = result.metrics.get('loss:avg', 'N/A') + log(f'[{test_name}] loss:avg = {loss_avg}') + return True, result, elapsed + except Exception as e: + elapsed = time.time() - start + log(f'[{test_name}] FAILED in {elapsed:.1f}s: {type(e).__name__}: {e}') + if elapsed > TIMEOUT: + log(f'[{test_name}] TIMEOUT! This suggests NCCL hang!') + return False, None, elapsed + + +def do_optim_step(training_client, test_name): + """Run optimizer step.""" + from tinker import types + try: + training_client.optim_step(types.AdamParams(learning_rate=1e-5)).result() + log(f'[{test_name}] optim_step OK') + return True + except Exception as e: + log(f'[{test_name}] optim_step FAILED: {e}') + return False + + +# ═══════════════════════════════════════════════════════════════════════════ +# Test Scenarios (19 tests) +# ═══════════════════════════════════════════════════════════════════════════ + +def test_1_normal_grpo(tc): + datums = [make_datum(seq_len=64, completion_len=32) for _ in range(4)] + ok, result, elapsed = run_forward_backward(tc, datums, 'TEST-1-NORMAL') + assert ok and elapsed < TIMEOUT + do_optim_step(tc, 'TEST-1-NORMAL') + return True + +def test_2_bad_old_logps(tc): + datums = [ + make_datum(seq_len=64, completion_len=32), + make_datum(seq_len=64, completion_len=32, bad_logprobs_len=5), + make_datum(seq_len=64, completion_len=32), + make_datum(seq_len=64, completion_len=32, bad_logprobs_len=99), + ] + ok, result, elapsed = run_forward_backward(tc, datums, 'TEST-2-BAD-LOGPS') + if not ok: + return elapsed < TIMEOUT + assert elapsed < TIMEOUT + do_optim_step(tc, 'TEST-2-BAD-LOGPS') + return True + +def test_3_recovery(tc): + datums = [make_datum(seq_len=64, completion_len=32) for _ in range(4)] + ok, _, elapsed = run_forward_backward(tc, datums, 'TEST-3-RECOVERY') + assert ok and elapsed < TIMEOUT + do_optim_step(tc, 'TEST-3-RECOVERY') + return True + +def test_4_no_advantages(tc): + datums = [make_datum(seq_len=64, completion_len=32, include_advantages=False) for _ in range(4)] + ok, _, elapsed = run_forward_backward(tc, datums, 'TEST-4-NO-ADV') + assert ok and elapsed < TIMEOUT + do_optim_step(tc, 'TEST-4-NO-ADV') + return True + +def test_5_consecutive_bad(tc): + for i in range(5): + datums = [make_datum(seq_len=64, completion_len=32, bad_logprobs_len=3+i) for _ in range(4)] + _, _, elapsed = run_forward_backward(tc, datums, f'TEST-5-{i+1}') + if elapsed >= TIMEOUT: + return False + do_optim_step(tc, f'TEST-5-{i+1}') + return True + +def test_6_nan_logprobs(tc): + from tinker import types + datums = [] + for _ in range(4): + d = make_datum(seq_len=64, completion_len=32) + d.loss_fn_inputs['logprobs'] = types.TensorData.from_numpy( + np.array([float('nan')] * 64, dtype=np.float32)) + datums.append(d) + _, _, elapsed = run_forward_backward(tc, datums, 'TEST-6-NAN') + if elapsed >= TIMEOUT: + return False + do_optim_step(tc, 'TEST-6-NAN') + return True + +def test_7_inf_logprobs(tc): + from tinker import types + datums = [] + for _ in range(4): + d = make_datum(seq_len=64, completion_len=32) + inf_arr = np.full(64, float('inf'), dtype=np.float32) + inf_arr[::2] = float('-inf') + d.loss_fn_inputs['logprobs'] = types.TensorData.from_numpy(inf_arr) + datums.append(d) + _, _, elapsed = run_forward_backward(tc, datums, 'TEST-7-INF') + if elapsed >= TIMEOUT: + return False + do_optim_step(tc, 'TEST-7-INF') + return True + +def test_8_extreme_advantages(tc): + from tinker import types + datums = [] + for i in range(4): + d = make_datum(seq_len=64, completion_len=32) + val = 1e30 if i % 2 == 0 else -1e30 + adv = np.full(64, 0.0, dtype=np.float32) + adv[32:] = val + d.loss_fn_inputs['advantages'] = types.TensorData.from_numpy(adv) + datums.append(d) + _, _, elapsed = run_forward_backward(tc, datums, 'TEST-8-EXTREME-ADV') + if elapsed >= TIMEOUT: + return False + do_optim_step(tc, 'TEST-8-EXTREME-ADV') + return True + +def test_9_zero_completion(tc): + from tinker import types + datums = [] + for _ in range(4): + d = types.Datum( + model_input=types.ModelInput.from_ints(list(range(1, 65))), + loss_fn_inputs={ + 'target_tokens': [0]*64, 'weights': [0]*64, + 'logprobs': types.TensorData.from_numpy(np.zeros(64, dtype=np.float32)), + 'advantages': types.TensorData.from_numpy(np.zeros(64, dtype=np.float32)), + }, + ) + datums.append(d) + _, _, elapsed = run_forward_backward(tc, datums, 'TEST-9-ZERO-COMPL') + if elapsed >= TIMEOUT: + return False + do_optim_step(tc, 'TEST-9-ZERO-COMPL') + return True + +def test_10_partial_advantages(tc): + datums = [ + make_datum(seq_len=64, completion_len=32, include_advantages=True), + make_datum(seq_len=64, completion_len=32, include_advantages=False), + make_datum(seq_len=64, completion_len=32, include_advantages=True), + make_datum(seq_len=64, completion_len=32, include_advantages=False), + ] + _, _, elapsed = run_forward_backward(tc, datums, 'TEST-10-PARTIAL-ADV') + if elapsed >= TIMEOUT: + return False + do_optim_step(tc, 'TEST-10-PARTIAL-ADV') + return True + +def test_11_mixed_seq_lengths(tc): + datums = [ + make_datum(seq_len=32, completion_len=16), + make_datum(seq_len=128, completion_len=64), + make_datum(seq_len=48, completion_len=24), + make_datum(seq_len=96, completion_len=48), + ] + _, _, elapsed = run_forward_backward(tc, datums, 'TEST-11-MIXED') + if elapsed >= TIMEOUT: + return False + do_optim_step(tc, 'TEST-11-MIXED') + return True + +def test_12_all_bad(tc): + datums = [make_datum(seq_len=64, completion_len=32, bad_logprobs_len=i) for i in range(4)] + _, _, elapsed = run_forward_backward(tc, datums, 'TEST-12-ALL-BAD') + if elapsed >= TIMEOUT: + return False + do_optim_step(tc, 'TEST-12-ALL-BAD') + return True + +def test_13_forward_only_then_train(tc): + datums_infer = [make_datum(seq_len=64, completion_len=32, include_advantages=False) for _ in range(4)] + start = time.time() + try: + tc.forward(datums_infer).result() + except Exception: + if time.time() - start >= TIMEOUT: + return False + datums_train = [make_datum(seq_len=64, completion_len=32) for _ in range(4)] + ok, _, elapsed = run_forward_backward(tc, datums_train, 'TEST-13-TRAIN') + if not ok or elapsed >= TIMEOUT: + return False + do_optim_step(tc, 'TEST-13-TRAIN') + return True + +def test_14_rapid_bad_good(tc): + for i in range(5): + bad = [make_datum(seq_len=64, completion_len=32, bad_logprobs_len=i+1) for _ in range(4)] + _, _, elapsed = run_forward_backward(tc, bad, f'TEST-14-BAD-{i+1}') + if elapsed >= TIMEOUT: + return False + do_optim_step(tc, f'TEST-14-BAD-{i+1}') + good = [make_datum(seq_len=64, completion_len=32) for _ in range(4)] + ok, _, elapsed = run_forward_backward(tc, good, f'TEST-14-GOOD-{i+1}') + if not ok or elapsed >= TIMEOUT: + return False + do_optim_step(tc, f'TEST-14-GOOD-{i+1}') + return True + +def test_15_final_health(tc): + datums = [make_datum(seq_len=64, completion_len=32) for _ in range(4)] + ok, _, elapsed = run_forward_backward(tc, datums, 'TEST-15-FINAL') + assert ok and elapsed < TIMEOUT + do_optim_step(tc, 'TEST-15-FINAL') + return True + +def test_16_large_batch(tc): + datums = [make_datum(seq_len=64, completion_len=32) for _ in range(16)] + ok, _, elapsed = run_forward_backward(tc, datums, 'TEST-16-LARGE') + if elapsed >= TIMEOUT: + return False + assert ok + do_optim_step(tc, 'TEST-16-LARGE') + return True + +def test_17_single_datum(tc): + # With dp_size=2 + nproc_per_node=2, minimum batch must be >= data_world_size + datums = [make_datum(seq_len=64, completion_len=32) for _ in range(4)] + ok, _, elapsed = run_forward_backward(tc, datums, 'TEST-17-SMALL') + if elapsed >= TIMEOUT: + return False + assert ok + do_optim_step(tc, 'TEST-17-SMALL') + return True + +def test_18_save_after_error(tc): + bad = [make_datum(seq_len=64, completion_len=32, bad_logprobs_len=2) for _ in range(4)] + _, _, elapsed = run_forward_backward(tc, bad, 'TEST-18-ERR') + if elapsed >= TIMEOUT: + return False + do_optim_step(tc, 'TEST-18-ERR') + try: + tc.save_weights_for_sampler().result() + except Exception: + pass + good = [make_datum(seq_len=64, completion_len=32) for _ in range(4)] + ok, _, elapsed = run_forward_backward(tc, good, 'TEST-18-POST') + if not ok or elapsed >= TIMEOUT: + return False + do_optim_step(tc, 'TEST-18-POST') + return True + +def test_19_consecutive_optim_steps(tc): + datums = [make_datum(seq_len=64, completion_len=32) for _ in range(4)] + ok, _, elapsed = run_forward_backward(tc, datums, 'TEST-19-BASE') + assert ok and elapsed < TIMEOUT + for i in range(3): + do_optim_step(tc, f'TEST-19-STEP-{i+1}') + datums = [make_datum(seq_len=64, completion_len=32) for _ in range(4)] + ok, _, elapsed = run_forward_backward(tc, datums, 'TEST-19-VERIFY') + if not ok or elapsed >= TIMEOUT: + return False + do_optim_step(tc, 'TEST-19-VERIFY') + return True + + +ALL_TESTS = [ + ('TEST-1: Normal GRPO Training', test_1_normal_grpo), + ('TEST-2: Bad old_logps (original bug)', test_2_bad_old_logps), + ('TEST-3: Recovery after error', test_3_recovery), + ('TEST-4: No advantages (zero loss)', test_4_no_advantages), + ('TEST-5: Consecutive bad batches', test_5_consecutive_bad), + ('TEST-6: NaN logprobs', test_6_nan_logprobs), + ('TEST-7: +Inf/-Inf logprobs', test_7_inf_logprobs), + ('TEST-8: Extreme advantages (1e30)', test_8_extreme_advantages), + ('TEST-9: Zero completion tokens', test_9_zero_completion), + ('TEST-10: Partial advantages (ragged)', test_10_partial_advantages), + ('TEST-11: Mixed sequence lengths', test_11_mixed_seq_lengths), + ('TEST-12: All datums bad (100%)', test_12_all_bad), + ('TEST-13: forward_only then train', test_13_forward_only_then_train), + ('TEST-14: Rapid bad->good alternation', test_14_rapid_bad_good), + ('TEST-15: Final health check', test_15_final_health), + ('TEST-16: Large batch (16 datums)', test_16_large_batch), + ('TEST-17: Single datum batch', test_17_single_datum), + ('TEST-18: Save after error', test_18_save_after_error), + ('TEST-19: Consecutive optim_steps', test_19_consecutive_optim_steps), +] + + +def main(): + log('=' * 60) + log('NCCL-Safe E2E Test - Tinker Client Path') + log('=' * 60) + log(f'Server URL: {SERVER_URL}') + log(f'Base Model: {BASE_MODEL}') + log(f'TWINKLE_FAIL_FAST = {os.getenv("TWINKLE_FAIL_FAST", "1 (default)")}') + + wait_for_server(SERVER_URL) + tc = init_client() + + results = [] + for name, test_fn in ALL_TESTS: + log(f'\n{"=" * 60}\n{name}\n{"=" * 60}') + try: + passed = test_fn(tc) + results.append((name, 'PASS' if passed else 'FAIL')) + log(f'[{name}] {"PASS" if passed else "FAIL"}') + except Exception as e: + log(f'{name}: EXCEPTION: {e}') + traceback.print_exc() + results.append((name, 'FAIL')) + + log(f'\n{"=" * 60}\nRESULTS SUMMARY\n{"=" * 60}') + all_passed = all(s == 'PASS' for _, s in results) + for name, status in results: + log(f' [{status}] {name}') + log(f'\n{"ALL" if all_passed else "SOME"} {len(results)} TESTS {"PASSED" if all_passed else "FAILED"}!') + return 0 if all_passed else 1 + + +def test_nccl_safe_tinker_e2e(): + """Pytest-collected entry point.""" + rc = main() + assert rc == 0, 'Some Tinker NCCL-safe E2E tests failed' + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/tests/server/integration/test_nccl_safe_twinkle_e2e.py b/tests/server/integration/test_nccl_safe_twinkle_e2e.py new file mode 100644 index 000000000..c9cce48a6 --- /dev/null +++ b/tests/server/integration/test_nccl_safe_twinkle_e2e.py @@ -0,0 +1,345 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Real E2E test for NCCL-safe fault tolerance via Twinkle client path. + +Exercises the /twinkle/forward_backward endpoint through the Twinkle SDK +(init_twinkle_client + MultiLoraTransformersModel). This is a SEPARATE code +path from the Tinker SDK (/tinker/forward_backward). + +Prerequisites: + 1. Ray cluster running with GPUs (2 for model DP/TP, optionally 1 for sampler) + 2. Twinkle server started with TWINKLE_FAIL_FAST=0 + +Usage (direct): + python tests/server/integration/test_nccl_safe_twinkle_e2e.py + +Usage (pytest, requires TWINKLE_TEST_GPU_E2E=1): + TWINKLE_TEST_GPU_E2E=1 pytest tests/server/integration/test_nccl_safe_twinkle_e2e.py -v +""" +from __future__ import annotations + +import os +import sys +import time +import logging +import traceback +from typing import Any, Dict, List + +import numpy as np +import pytest + +pytestmark = pytest.mark.skipif( + os.environ.get('TWINKLE_TEST_GPU_E2E', '0') != '1', + reason='Set TWINKLE_TEST_GPU_E2E=1 to run real GPU E2E tests (requires running server)', +) + +logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s') +logger = logging.getLogger(__name__) + + +def log(msg): + print(f'[E2E-Twinkle] {msg}', flush=True) + + +BASE_MODEL = 'Qwen/Qwen3.5-4B' +SERVER_URL = os.environ.get('TWINKLE_SERVER_URL', 'http://localhost:9000') +TIMEOUT = 120 +ADAPTER_NAME = 'nccl-safe-test' + + +def wait_for_server(url, timeout=300): + """Wait for Twinkle server to become ready.""" + import requests + start = time.time() + while time.time() - start < timeout: + try: + resp = requests.get(f'{url}/-/routes', timeout=5) + if resp.status_code == 200: + log(f'Server is ready (waited {int(time.time() - start)}s)') + return True + except Exception: + pass + time.sleep(5) + raise TimeoutError(f'Server not ready after {timeout}s') + + +def init_client(): + """Initialize Twinkle client and configure model for GRPO training.""" + from twinkle_client import init_twinkle_client + from twinkle_client.model import MultiLoraTransformersModel + from peft import LoraConfig + + init_twinkle_client(base_url=SERVER_URL, api_key='EMPTY_TOKEN') + + model = MultiLoraTransformersModel(model_id=f'ms://{BASE_MODEL}') + model.add_adapter_to_model( + adapter_name=ADAPTER_NAME, + config=LoraConfig(r=16, target_modules=['q_proj', 'v_proj']), + gradient_accumulation_steps=1, + ) + model.set_loss('GRPOLoss', init_args={'epsilon': 0.2}) + model.set_optimizer('Adam', lr=1e-5) + model.set_template('Qwen3_5Template') + model.set_processor('InputProcessor', padding_side='right') + log('Twinkle client + model configured successfully') + return model + + +def make_input_features( + batch_size=4, seq_len=64, completion_len=32, *, + bad_old_logps_len=None, include_advantages=True, + nan_old_logps=False, extreme_advantages=None, all_labels_masked=False, +): + """Construct InputFeature list + old_logps + advantages for GRPO.""" + prompt_len = seq_len - completion_len + input_features = [] + old_logps_list = [] + advantages_list = [] + + for i in range(batch_size): + input_ids = list(range(1, seq_len + 1)) + labels = [-100] * seq_len if all_labels_masked else ( + [-100] * prompt_len + list(range(100, 100 + completion_len))) + input_features.append({ + 'input_ids': input_ids, + 'labels': labels, + 'attention_mask': [1] * seq_len, + 'position_ids': list(range(seq_len)), + }) + + if bad_old_logps_len is not None: + logps = np.random.randn(bad_old_logps_len).tolist() + elif nan_old_logps: + logps = [float('nan')] * completion_len + else: + logps = np.random.randn(completion_len).tolist() + old_logps_list.append(logps) + + if extreme_advantages is not None: + advantages_list.append(extreme_advantages if i % 2 == 0 else -extreme_advantages) + else: + advantages_list.append(float(np.random.randn())) + + old_logps = old_logps_list if include_advantages else None + advantages = advantages_list if include_advantages else None + return input_features, old_logps, advantages + + +def run_forward_backward(model, inputs, old_logps, advantages, test_name): + """Run forward_backward and return (success, result, elapsed_seconds).""" + log(f'[{test_name}] Sending {len(inputs)} input features...') + start = time.time() + try: + kwargs: Dict[str, Any] = {} + if old_logps is not None: + kwargs['old_logps'] = old_logps + if advantages is not None: + kwargs['advantages'] = advantages + + result = model.forward_backward(inputs=inputs, **kwargs) + elapsed = time.time() - start + log(f'[{test_name}] Completed in {elapsed:.1f}s') + if hasattr(result, 'result') and result.result is not None: + log(f'[{test_name}] result = {result.result}') + return True, result, elapsed + except Exception as e: + elapsed = time.time() - start + log(f'[{test_name}] FAILED in {elapsed:.1f}s: {type(e).__name__}: {e}') + if elapsed > TIMEOUT: + log(f'[{test_name}] TIMEOUT! This suggests NCCL hang!') + return False, None, elapsed + + +def do_optim_step(model, test_name): + """Run clip_grad_and_step.""" + try: + model.clip_grad_and_step() + log(f'[{test_name}] clip_grad_and_step OK') + return True + except Exception as e: + log(f'[{test_name}] clip_grad_and_step FAILED: {e}') + return False + + +# ═══════════════════════════════════════════════════════════════════════════ +# Test Scenarios (12 tests) +# ═══════════════════════════════════════════════════════════════════════════ + +def test_1_normal_grpo(m): + inputs, old_logps, adv = make_input_features(batch_size=4) + ok, _, elapsed = run_forward_backward(m, inputs, old_logps, adv, 'TEST-1-NORMAL') + assert ok and elapsed < TIMEOUT + do_optim_step(m, 'TEST-1-NORMAL') + return True + +def test_2_bad_old_logps(m): + inputs, old_logps, adv = make_input_features(batch_size=4, bad_old_logps_len=5) + ok, _, elapsed = run_forward_backward(m, inputs, old_logps, adv, 'TEST-2-BAD-LOGPS') + if not ok: + return elapsed < TIMEOUT + assert elapsed < TIMEOUT + do_optim_step(m, 'TEST-2-BAD-LOGPS') + return True + +def test_3_recovery(m): + inputs, old_logps, adv = make_input_features(batch_size=4) + ok, _, elapsed = run_forward_backward(m, inputs, old_logps, adv, 'TEST-3-RECOVERY') + assert ok and elapsed < TIMEOUT + do_optim_step(m, 'TEST-3-RECOVERY') + return True + +def test_4_nan_old_logps(m): + inputs, old_logps, adv = make_input_features(batch_size=4, nan_old_logps=True) + _, _, elapsed = run_forward_backward(m, inputs, old_logps, adv, 'TEST-4-NAN') + if elapsed >= TIMEOUT: + return False + do_optim_step(m, 'TEST-4-NAN') + return True + +def test_5_extreme_advantages(m): + inputs, old_logps, adv = make_input_features(batch_size=4, extreme_advantages=1e30) + _, _, elapsed = run_forward_backward(m, inputs, old_logps, adv, 'TEST-5-EXTREME') + if elapsed >= TIMEOUT: + return False + do_optim_step(m, 'TEST-5-EXTREME') + return True + +def test_6_all_labels_masked(m): + inputs, old_logps, adv = make_input_features(batch_size=4, all_labels_masked=True) + _, _, elapsed = run_forward_backward(m, inputs, old_logps, adv, 'TEST-6-MASKED') + if elapsed >= TIMEOUT: + return False + do_optim_step(m, 'TEST-6-MASKED') + return True + +def test_7_consecutive_bad(m): + for i in range(5): + inputs, old_logps, adv = make_input_features(batch_size=4, bad_old_logps_len=i+1) + _, _, elapsed = run_forward_backward(m, inputs, old_logps, adv, f'TEST-7-{i+1}') + if elapsed >= TIMEOUT: + return False + do_optim_step(m, f'TEST-7-{i+1}') + return True + +def test_8_rapid_bad_good(m): + for i in range(5): + bad_in, bad_lp, bad_adv = make_input_features(batch_size=4, bad_old_logps_len=i+1) + _, _, elapsed = run_forward_backward(m, bad_in, bad_lp, bad_adv, f'TEST-8-BAD-{i+1}') + if elapsed >= TIMEOUT: + return False + do_optim_step(m, f'TEST-8-BAD-{i+1}') + good_in, good_lp, good_adv = make_input_features(batch_size=4) + ok, _, elapsed = run_forward_backward(m, good_in, good_lp, good_adv, f'TEST-8-GOOD-{i+1}') + if not ok or elapsed >= TIMEOUT: + return False + do_optim_step(m, f'TEST-8-GOOD-{i+1}') + return True + +def test_9_final_health(m): + inputs, old_logps, adv = make_input_features(batch_size=4) + ok, _, elapsed = run_forward_backward(m, inputs, old_logps, adv, 'TEST-9-FINAL') + assert ok and elapsed < TIMEOUT + do_optim_step(m, 'TEST-9-FINAL') + return True + +def test_10_gradient_accumulation_error(m): + inputs, lp, adv = make_input_features(batch_size=4) + ok, _, elapsed = run_forward_backward(m, inputs, lp, adv, 'TEST-10-GA1') + if not ok or elapsed >= TIMEOUT: + return False + bad_in, bad_lp, bad_adv = make_input_features(batch_size=4, bad_old_logps_len=3) + _, _, elapsed = run_forward_backward(m, bad_in, bad_lp, bad_adv, 'TEST-10-GA2-BAD') + if elapsed >= TIMEOUT: + return False + inputs, lp, adv = make_input_features(batch_size=4) + ok, _, elapsed = run_forward_backward(m, inputs, lp, adv, 'TEST-10-GA3') + if not ok or elapsed >= TIMEOUT: + return False + do_optim_step(m, 'TEST-10-GA') + return True + +def test_11_forward_only_then_train(m): + inputs, _, _ = make_input_features(batch_size=4, include_advantages=False) + start = time.time() + try: + m.forward_only(inputs=inputs) + except Exception: + if time.time() - start >= TIMEOUT: + return False + train_in, lp, adv = make_input_features(batch_size=4) + ok, _, elapsed = run_forward_backward(m, train_in, lp, adv, 'TEST-11-TRAIN') + if not ok or elapsed >= TIMEOUT: + return False + do_optim_step(m, 'TEST-11-TRAIN') + return True + +def test_12_mixed_seq_lengths(m): + all_inputs, all_lp, all_adv = [], [], [] + for sl, cl in [(32, 16), (128, 64), (48, 24), (96, 48)]: + feats, lp, adv = make_input_features(batch_size=1, seq_len=sl, completion_len=cl) + all_inputs.extend(feats) + if lp: + all_lp.extend(lp) + if adv: + all_adv.extend(adv) + _, _, elapsed = run_forward_backward(m, all_inputs, all_lp, all_adv, 'TEST-12-MIXED') + if elapsed >= TIMEOUT: + return False + do_optim_step(m, 'TEST-12-MIXED') + return True + + +ALL_TESTS = [ + ('TEST-1: Normal GRPO Training', test_1_normal_grpo), + ('TEST-2: Bad old_logps (original bug)', test_2_bad_old_logps), + ('TEST-3: Recovery after error', test_3_recovery), + ('TEST-4: NaN old_logps', test_4_nan_old_logps), + ('TEST-5: Extreme advantages (1e30)', test_5_extreme_advantages), + ('TEST-6: All labels masked (-100)', test_6_all_labels_masked), + ('TEST-7: Consecutive bad batches', test_7_consecutive_bad), + ('TEST-8: Rapid bad->good', test_8_rapid_bad_good), + ('TEST-9: Final health check', test_9_final_health), + ('TEST-10: Gradient accumulation error', test_10_gradient_accumulation_error), + ('TEST-11: forward_only then train', test_11_forward_only_then_train), + ('TEST-12: Mixed sequence lengths', test_12_mixed_seq_lengths), +] + + +def main(): + log('=' * 60) + log('NCCL-Safe E2E Test - Twinkle Client Path') + log('=' * 60) + log(f'Server URL: {SERVER_URL}') + log(f'Base Model: {BASE_MODEL}') + log(f'TWINKLE_FAIL_FAST = {os.getenv("TWINKLE_FAIL_FAST", "1 (default)")}') + + wait_for_server(SERVER_URL) + m = init_client() + + results = [] + for name, test_fn in ALL_TESTS: + log(f'\n{"=" * 60}\n{name}\n{"=" * 60}') + try: + passed = test_fn(m) + results.append((name, 'PASS' if passed else 'FAIL')) + log(f'[{name}] {"PASS" if passed else "FAIL"}') + except Exception as e: + log(f'{name}: EXCEPTION: {e}') + traceback.print_exc() + results.append((name, 'FAIL')) + + log(f'\n{"=" * 60}\nRESULTS SUMMARY\n{"=" * 60}') + all_passed = all(s == 'PASS' for _, s in results) + for name, status in results: + log(f' [{status}] {name}') + log(f'\n{"ALL" if all_passed else "SOME"} {len(results)} TESTS {"PASSED" if all_passed else "FAILED"}!') + return 0 if all_passed else 1 + + +def test_nccl_safe_twinkle_e2e(): + """Pytest-collected entry point.""" + rc = main() + assert rc == 0, 'Some Twinkle NCCL-safe E2E tests failed' + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/tests/server/integration/test_sft_e2e.py b/tests/server/integration/test_sft_e2e.py new file mode 100644 index 000000000..40c794b11 --- /dev/null +++ b/tests/server/integration/test_sft_e2e.py @@ -0,0 +1,194 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""SFT (Supervised Fine-Tuning) E2E integration tests. + +Tests SFT training across all 4 combinations: + - Twinkle client x (transformers | megatron) + - Tinker client x (transformers | megatron) + +Backend selection via env var TWINKLE_TEST_BACKEND (default: transformers). + +## How to run + + # Start server (transformers or megatron) + python tests/server/start_e2e_server.py --config tests/server/config/server_config_4b_e2e.yaml + + # Run SFT tests + TWINKLE_TEST_GPU_E2E=1 TWINKLE_TEST_BACKEND=transformers pytest tests/server/integration/test_sft_e2e.py -v +""" +from __future__ import annotations + +import os +import sys +import time + +# Ensure project root is in sys.path for both pytest and direct execution +_PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..')) +if _PROJECT_ROOT not in sys.path: + sys.path.insert(0, _PROJECT_ROOT) + +import numpy as np +import pytest + +pytestmark = pytest.mark.skipif( + os.environ.get('TWINKLE_TEST_GPU_E2E', '0') != '1', + reason='Set TWINKLE_TEST_GPU_E2E=1 to run real GPU E2E tests (requires running server)', +) + +from tests.server.integration.e2e_helpers import ( + BASE_URL, + GRADIENT_ACCUMULATION_STEPS, + TIMEOUT, + assert_loss_decreases, + assert_no_timeout, + convert_tensors, + create_sft_dataset, + create_tinker_training_client, + create_twinkle_sft_model, + get_backend, + init_twinkle_client_session, + log, + wait_for_server, +) + +# ── Configuration ── +SFT_TRAIN_STEPS = 20 # 20 steps ensures enough training for both backends + + +# ═══════════════════════════════════════════════════════════════════════════ +# Test: SFT via Twinkle client +# ═══════════════════════════════════════════════════════════════════════════ + +def test_sft_twinkle(): + """SFT training via Twinkle client (MultiLoraTransformersModel). + + Pass criteria: + - Training completes 10 steps without timeout + - Loss shows downward trend (last_3_avg < first_3_avg) + """ + backend = get_backend() + log(f'=== test_sft_twinkle [backend={backend}] ===') + + wait_for_server() + init_twinkle_client_session() + + # Setup + from twinkle.dataloader import DataLoader + + dataset = create_sft_dataset() + dataloader = DataLoader(dataset=dataset, batch_size=4) + model = create_twinkle_sft_model() + + log(f'Dataset: {len(dataset)} samples, {len(dataloader)} batches') + log(f'Training {SFT_TRAIN_STEPS} steps (GA={GRADIENT_ACCUMULATION_STEPS})') + + # Training loop + losses = [] + for step, batch in enumerate(dataloader): + if step >= SFT_TRAIN_STEPS: + break + + t0 = time.time() + model.forward_backward(inputs=batch) + model.clip_grad_and_step() + elapsed = time.time() - t0 + assert_no_timeout(elapsed, f'sft_twinkle step {step}') + + # Log metric every GA steps + if (step + 1) % GRADIENT_ACCUMULATION_STEPS == 0: + metric = model.calculate_metric(is_training=True) + try: + loss = float(metric.result.get('loss')) if hasattr(metric.result, 'get') else float( + metric.result['loss']) + except Exception: + loss = float('nan') + losses.append(loss) + log(f'[step {step + 1}] loss={loss:.4f} ({elapsed:.1f}s)') + + # Assertions — both backends should report real loss via calculate_metric + assert len(losses) >= 4, f'Expected at least 4 logged losses, got {len(losses)}' + assert_loss_decreases(losses, 'sft_twinkle') + log(f'test_sft_twinkle PASSED (backend={backend})') + + +# ═══════════════════════════════════════════════════════════════════════════ +# Test: SFT via Tinker client +# ═══════════════════════════════════════════════════════════════════════════ + +def test_sft_tinker(): + """SFT training via Tinker client (ServiceClient + forward_backward). + + Pass criteria: + - Training completes 10 steps without timeout + - Loss shows downward trend (last_3_avg < first_3_avg) + """ + from tinker import types + from twinkle.dataloader import DataLoader + from twinkle.server.common import input_feature_to_datum + + backend = get_backend() + log(f'=== test_sft_tinker [backend={backend}] ===') + + wait_for_server() + + # Setup + dataset = create_sft_dataset() + dataloader = DataLoader(dataset=dataset, batch_size=4) + training_client = create_tinker_training_client(rank=16) + + log(f'Dataset: {len(dataset)} samples, {len(dataloader)} batches') + log(f'Training {SFT_TRAIN_STEPS} steps') + + # Training loop + losses = [] + for step, batch in enumerate(dataloader): + if step >= SFT_TRAIN_STEPS: + break + + # Convert batch to Tinker Datums + input_datums = [input_feature_to_datum(input_feature) for input_feature in batch] + + # Forward-backward + t0 = time.time() + fwdbwd_result = training_client.forward_backward(input_datums, 'cross_entropy').result() + elapsed_fb = time.time() - t0 + assert_no_timeout(elapsed_fb, f'sft_tinker forward_backward step {step}') + + # Optimizer step + optim_result = training_client.optim_step(types.AdamParams(learning_rate=1e-4)).result() + elapsed_total = time.time() - t0 + assert_no_timeout(elapsed_total, f'sft_tinker total step {step}') + + # Compute loss from logprobs + try: + logprobs = np.concatenate([output['logprobs'].tolist() for output in fwdbwd_result.loss_fn_outputs]) + weights = np.concatenate([example.loss_fn_inputs['weights'].tolist() for example in input_datums]) + loss = float(-np.dot(logprobs, weights) / max(weights.sum(), 1e-8)) + except Exception: + loss = float('nan') + losses.append(loss) + log(f'[step {step + 1}] loss={loss:.4f} ({elapsed_total:.1f}s)') + + # Assertions + assert len(losses) >= 4, f'Expected at least 4 logged losses, got {len(losses)}' + assert_loss_decreases(losses, 'sft_tinker') + log(f'test_sft_tinker PASSED (backend={backend})') + + +# ── Direct execution ── + +def main() -> int: + log('Running SFT E2E tests directly...') + try: + test_sft_twinkle() + test_sft_tinker() + log('ALL SFT TESTS PASSED') + return 0 + except Exception as e: + log(f'FAILED: {e}') + import traceback + traceback.print_exc() + return 1 + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/tests/server/model/test_tinker_handlers.py b/tests/server/model/test_tinker_handlers.py index 72efbaf90..a5330172f 100644 --- a/tests/server/model/test_tinker_handlers.py +++ b/tests/server/model/test_tinker_handlers.py @@ -1,4 +1,5 @@ import pytest +from unittest.mock import AsyncMock, MagicMock, patch from fastapi import FastAPI from starlette.requests import Request from tinker import types @@ -46,3 +47,81 @@ async def test_tinker_dpo_forward_backward_requires_per_dp_pairs(): assert management.scheduled[-1]['batch_size'] == 2 assert management.scheduled[-1]['data_world_size'] == 2 assert management.scheduled[-1]['batch_size_multiple'] == 2 + + +class _SaveWeightsDummyManagement: + """Dummy management that actually executes the task to test save_weights_for_sampler logic.""" + + def __init__(self): + self.model = MagicMock() + self.state = MagicMock() + self.state.get_model_metadata = AsyncMock(return_value={'base_model': 'test-model'}) + self.state.create_sampling_session = AsyncMock(return_value='session-123') + + async def _on_request_start(self, request): + return 'token1' + + def get_adapter_name(self, adapter_name=None): + return adapter_name + + def assert_resource_exists(self, adapter_name): + pass + + async def schedule_task(self, task, **kwargs): + # Actually execute the task to test response logic + return await task() + + +@pytest.mark.asyncio +@patch('twinkle.server.model.tinker_handlers.create_checkpoint_manager') +async def test_save_weights_for_sampler_path_mode_returns_path(mock_create_ckpt_mgr): + """save_weights_for_sampler(name) mode: sampling_session_seq_id is None → returns path != None.""" + mock_ckpt_mgr = MagicMock() + mock_ckpt_mgr.get_ckpt_name.return_value = 'step-1' + mock_ckpt_mgr.get_save_dir.return_value = '/tmp/save_dir' + mock_ckpt_mgr.save.return_value = 'twinkle://model1/sampler_weights/20260101_000000' + mock_create_ckpt_mgr.return_value = mock_ckpt_mgr + + management = _SaveWeightsDummyManagement() + app = FastAPI() + _register_tinker_routes(app, lambda: management) + + body = types.SaveWeightsForSamplerRequest( + model_id='model1', + path='step-1', + sampling_session_seq_id=None, # path mode + ) + + route = next(route for route in app.routes if getattr(route, 'path', None) == '/tinker/save_weights_for_sampler') + request = Request({'type': 'http', 'headers': []}) + response = await route.endpoint(request, body, management) + + assert response.path == 'twinkle://model1/sampler_weights/20260101_000000' + assert response.sampling_session_id == 'session-123' + + +@pytest.mark.asyncio +@patch('twinkle.server.model.tinker_handlers.create_checkpoint_manager') +async def test_save_weights_for_sampler_session_mode_returns_none_path(mock_create_ckpt_mgr): + """save_weights_and_get_sampling_client() mode: sampling_session_seq_id is set → returns path == None.""" + mock_ckpt_mgr = MagicMock() + mock_ckpt_mgr.get_ckpt_name.return_value = 'step-1' + mock_ckpt_mgr.get_save_dir.return_value = '/tmp/save_dir' + mock_ckpt_mgr.save.return_value = 'twinkle://model1/sampler_weights/20260101_000000' + mock_create_ckpt_mgr.return_value = mock_ckpt_mgr + + management = _SaveWeightsDummyManagement() + app = FastAPI() + _register_tinker_routes(app, lambda: management) + + body = types.SaveWeightsForSamplerRequest( + model_id='model1', + sampling_session_seq_id=0, # session mode + ) + + route = next(route for route in app.routes if getattr(route, 'path', None) == '/tinker/save_weights_for_sampler') + request = Request({'type': 'http', 'headers': []}) + response = await route.endpoint(request, body, management) + + assert response.path is None + assert response.sampling_session_id == 'session-123' diff --git a/tests/server/start_e2e_server.py b/tests/server/start_e2e_server.py new file mode 100644 index 000000000..dc2a9f694 --- /dev/null +++ b/tests/server/start_e2e_server.py @@ -0,0 +1,155 @@ +"""One-click: restart Ray cluster + launch Twinkle server + wait until ready. + +Usage: + python start_e2e_server.py # default config + python start_e2e_server.py --config my_config.yaml # custom config + python start_e2e_server.py --kill-only # just kill everything +""" +import argparse +import os +import signal +import subprocess +import sys +import time + +import requests + +# ── Paths ── +RAY = "/mnt/nas2/anaconda3/envs/tinker_myl/bin/ray" +PYTHON = "/mnt/nas2/anaconda3/envs/tinker_myl/bin/python" +WORKDIR = "/mnt/nas2/yunlin.myl/twinkle" +DEFAULT_CONFIG = "tests/server/config/server_config_4b_e2e.yaml" +RAY_TEMP_DIR = "/mnt/nas2/yunlin.myl/ray_logs" +SERVER_LOG = os.path.join(WORKDIR, "server_e2e.log") + +# ── Server check ── +SERVER_URL = "http://localhost:9000/-/routes" +READY_KEYWORD = "processor" +TIMEOUT = 600 +POLL_INTERVAL = 5 + + +def run(cmd, env=None, check=True): + """Run a shell command, print it, and return CompletedProcess.""" + full_env = os.environ.copy() + if env: + full_env.update(env) + print(f" $ {cmd}") + result = subprocess.run(cmd, shell=True, env=full_env, + capture_output=True, text=True) + if result.returncode != 0 and check: + print(f" STDERR: {result.stderr.strip()}") + return result + + +def kill_server(): + """Kill any existing twinkle.server processes.""" + print("[1/4] Killing existing Twinkle server processes...") + result = subprocess.run( + "ps aux | grep 'twinkle.server' | grep -v grep | awk '{print $2}'", + shell=True, capture_output=True, text=True + ) + pids = result.stdout.strip().split() + for pid in pids: + try: + os.kill(int(pid), signal.SIGKILL) + print(f" Killed PID {pid}") + except (ProcessLookupError, ValueError): + pass + if not pids: + print(" No existing server found") + time.sleep(2) + + +def restart_ray(): + """Stop and restart the Ray cluster (5 GPU + 1 CPU nodes).""" + print("[2/4] Restarting Ray cluster...") + run(f"{RAY} stop --force", check=False) + time.sleep(2) + + # Head node: GPU 0,1,2,3 (4 GPUs for model PP=2 x DP=2) + run(f"{RAY} start --head --port=6379 --num-gpus=4 " + f"--disable-usage-stats --temp-dir={RAY_TEMP_DIR}", + env={"CUDA_VISIBLE_DEVICES": "0,1,2,3"}) + + # Worker: GPU 4 (1 GPU for sampler) + run(f"{RAY} start --address=127.0.0.1:6379 --num-gpus=1", + env={"CUDA_VISIBLE_DEVICES": "4"}) + + # CPU-only worker (processor + server) + run(f"{RAY} start --address=127.0.0.1:6379 --num-gpus=0", + env={"CUDA_VISIBLE_DEVICES": ""}) + + print(" Ray cluster started (4+1+0 GPUs)") + + +def launch_server(config: str): + """Launch the Twinkle server in the background.""" + print(f"[3/4] Launching Twinkle server (config={config})...") + config_path = os.path.join(WORKDIR, config) + if not os.path.isfile(config_path): + print(f" ERROR: config not found: {config_path}", file=sys.stderr) + sys.exit(1) + + cmd = f"{PYTHON} -m twinkle.server launch --config {config_path}" + log_fd = open(SERVER_LOG, "w") + subprocess.Popen( + cmd, shell=True, cwd=WORKDIR, stdout=log_fd, stderr=log_fd, + env={**os.environ, "TWINKLE_TRUST_REMOTE_CODE": "1"}, + start_new_session=True, + ) + print(f" Server starting (log: {SERVER_LOG})") + + +def wait_ready(): + """Poll until the server is fully ready.""" + print(f"[4/4] Waiting for server to be ready (timeout={TIMEOUT}s)...") + start = time.time() + while time.time() - start < TIMEOUT: + try: + resp = requests.get(SERVER_URL, timeout=3) + if resp.ok and READY_KEYWORD in resp.text: + elapsed = time.time() - start + print(f" Server READY ({elapsed:.0f}s)") + return True + except (requests.ConnectionError, requests.Timeout): + pass + time.sleep(POLL_INTERVAL) + + print(f" TIMEOUT: server not ready after {TIMEOUT}s", file=sys.stderr) + print(f" Check log: tail -50 {SERVER_LOG}", file=sys.stderr) + return False + + +def main(): + parser = argparse.ArgumentParser(description="Restart Ray + launch Twinkle server") + parser.add_argument("--config", default=DEFAULT_CONFIG, + help=f"Server config yaml (default: {DEFAULT_CONFIG})") + parser.add_argument("--kill-only", action="store_true", + help="Only kill server + Ray, don't restart") + parser.add_argument("--no-ray", action="store_true", + help="Skip Ray restart (server only)") + args = parser.parse_args() + + kill_server() + + if args.kill_only: + run(f"{RAY} stop --force", check=False) + print("Done (kill-only)") + return 0 + + if not args.no_ray: + restart_ray() + + launch_server(args.config) + + if wait_ready(): + print("\n✓ All set. Run your test:") + print(f" TWINKLE_TEST_GPU_E2E=1 {PYTHON} -u tests/server/integration/test_full_cycle_e2e.py") + return 0 + else: + return 1 + + +if __name__ == "__main__": + sys.exit(main())