diff --git a/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py b/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py
index 77441f8f858..462de6afa65 100644
--- a/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py
+++ b/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py
@@ -110,6 +110,25 @@ def parse_args() -> argparse.Namespace:
"--trust_remote_code", action="store_true", help="Trust remote code for HF models."
)
parser.add_argument("--tp", type=int, default=None, help="Tensor parallel size.")
+ parser.add_argument(
+ "--block-size",
+ type=int,
+ default=None,
+ help="KV cache block size. Some models require a specific value — e.g. MiniMax-M3's "
+ "MSA sparse attention mandates 128. Default (None) lets vLLM choose.",
+ )
+ parser.add_argument(
+ "--language-model-only",
+ action="store_true",
+ help="Skip the vision encoder for text-only dumps (multimodal models, e.g. MiniMax-M3).",
+ )
+ parser.add_argument(
+ "--enforce-eager",
+ action="store_true",
+ help="Disable CUDA graph / torch.compile. Needed for MiniMax-M3: its MSA sparse "
+ "kernel JIT-recompiles per shape and a recompile can exceed the executor RPC "
+ "timeout under cudagraph capture, hanging the engine.",
+ )
parser.add_argument(
"--debug-max-num-conversations", type=int, default=None, help="Limit conversations."
)
@@ -168,7 +187,12 @@ def keep_conversation(entry):
# Resolve the aux-layer indices and append the final-layer output. vLLM saves the
# final (un-normed) hidden state when ``num_hidden_layers`` is passed as a layer id.
config = AutoConfig.from_pretrained(args.model, trust_remote_code=args.trust_remote_code)
- num_hidden_layers = getattr(config, "num_hidden_layers", None)
+ # Vision-language / wrapped configs (e.g. MiniMax-M3's MiniMaxM3VLConfig) nest the
+ # text model's layer count under text_config / llm_config rather than at the top level.
+ text_config = getattr(config, "text_config", None) or getattr(config, "llm_config", None)
+ num_hidden_layers = getattr(config, "num_hidden_layers", None) or getattr(
+ text_config, "num_hidden_layers", None
+ )
if num_hidden_layers is None:
raise ValueError(f"model config has no 'num_hidden_layers' attribute: {config}")
aux_layer_ids = _resolve_aux_layers_standalone(
@@ -244,12 +268,23 @@ def keep_conversation(entry):
storage_path.mkdir(parents=True, exist_ok=True)
atexit.register(shutil.rmtree, storage_path, ignore_errors=True)
+ # Model-specific extras (e.g. MiniMax-M3 mandates block_size=128 for MSA sparse
+ # attention; --language-model-only skips the vision encoder for text-only dumps).
+ extra_llm_kwargs = {}
+ if args.block_size is not None:
+ extra_llm_kwargs["block_size"] = args.block_size
+ if args.language_model_only:
+ extra_llm_kwargs["language_model_only"] = True
+ if args.enforce_eager:
+ extra_llm_kwargs["enforce_eager"] = True
+
llm = LLM(
model=args.model,
tensor_parallel_size=tp,
max_model_len=args.max_seq_len,
trust_remote_code=args.trust_remote_code,
enable_chunked_prefill=False, # required by extract_hidden_states
+ **extra_llm_kwargs,
# With prefix caching on, vLLM serves shared prefixes from cache in block-sized
# chunks and the hidden-state connector only emits the freshly-computed suffix, so
# the dumped hidden_states come out short by N*block_size vs the full input_ids /
diff --git a/examples/speculative_decoding/distributed_generate/worker.sh b/examples/speculative_decoding/distributed_generate/worker.sh
index 97bf14c014a..01f01bb746b 100644
--- a/examples/speculative_decoding/distributed_generate/worker.sh
+++ b/examples/speculative_decoding/distributed_generate/worker.sh
@@ -20,10 +20,15 @@ BACKEND="$2"
JOBS_PER_NODE="$3"
SYSTEM_PROMPT="$4"
+# Optional model-specific serve flags via env, appended to the serve command. E.g. for
+# MiniMax-M3: VLLM_SERVE_EXTRA_ARGS="--block-size 128 --language-model-only" (--block-size
+# 128 is mandatory for M3's MSA sparse attention; --language-model-only skips the vision
+# encoder for text-only synthesis; KV cache stays bf16 — M3's MSA fused kernel rejects
+# fp8 KV).
if [ "$BACKEND" == "vllm" ]; then
- vllm serve /model/ --tensor-parallel-size 8 --served-model-name model --port 8000 --host 0.0.0.0 --trust-remote-code &
+ vllm serve /model/ --tensor-parallel-size 8 --served-model-name model --port 8000 --host 0.0.0.0 --trust-remote-code ${VLLM_SERVE_EXTRA_ARGS:-} &
else
- python3 -m sglang.launch_server --model-path /model --served-model-name model --tp 8 --port 8000 --host 0.0.0.0 --trust-remote-code &
+ python3 -m sglang.launch_server --model-path /model --served-model-name model --tp 8 --port 8000 --host 0.0.0.0 --trust-remote-code ${SGLANG_SERVE_EXTRA_ARGS:-} &
fi
# Wait for server to start up by polling the health endpoint
echo "Waiting for server to start..."
@@ -59,6 +64,11 @@ if [ "$mpi_rank" -eq 0 ]; then
if [ -n "$SYSTEM_PROMPT" ]; then
cmd+=" --system_prompt $SYSTEM_PROMPT"
fi
+ # Optional: cycle thinking modes for a mixed dataset (e.g. MiniMax-M3
+ # THINKING_MODES="enabled,disabled,adaptive").
+ if [ -n "${THINKING_MODES:-}" ]; then
+ cmd+=" --thinking-modes $THINKING_MODES"
+ fi
echo "Running command: $cmd"
eval $cmd
done
diff --git a/examples/speculative_decoding/scripts/server_generate.py b/examples/speculative_decoding/scripts/server_generate.py
index 0fb71a0a0a1..6490055edc7 100644
--- a/examples/speculative_decoding/scripts/server_generate.py
+++ b/examples/speculative_decoding/scripts/server_generate.py
@@ -54,7 +54,29 @@
"--log_empty_conversations", action="store_true", help="Log empty conversations"
)
parser.add_argument("--system_prompt", nargs="+", type=str, default="", help="System prompt")
+parser.add_argument(
+ "--thinking-modes",
+ type=str,
+ default="",
+ help="Comma-separated thinking modes to cycle through per conversation, passed to the "
+ "server via chat_template_kwargs (e.g. 'enabled,disabled,adaptive' for MiniMax-M3). "
+ "Conversation i uses modes[i %% len(modes)], giving an even mix across the dataset. "
+ "Empty (default) sends no thinking_mode, preserving behavior for models without it.",
+)
+parser.add_argument(
+ "--output-format",
+ type=str,
+ default="oai",
+ choices=["oai", "sharegpt"],
+ help="Output chat format: 'oai' writes the OpenAI standard ({'messages': [{role, "
+ "content}, ...]}); 'sharegpt' writes the legacy {'conversations': [...]} key. Both "
+ "use role/content message dicts.",
+)
args = parser.parse_args()
+MESSAGES_KEY = "messages" if args.output_format == "oai" else "conversations"
+
+# Parse the thinking-mode cycle; empty -> no thinking_mode injected.
+THINKING_MODES = [m.strip() for m in args.thinking_modes.split(",") if m.strip()]
if args.data_path.endswith("jsonl"):
@@ -73,6 +95,14 @@ def generate_data(messages, idx, system_prompt):
try:
model_name = args.model
+ # Cycle thinking modes per conversation for an even mix across the dataset (e.g.
+ # MiniMax-M3 enabled/disabled/adaptive). Passed via chat_template_kwargs; empty
+ # list -> not sent.
+ thinking_mode = THINKING_MODES[idx % len(THINKING_MODES)] if THINKING_MODES else None
+ extra_body = (
+ {"chat_template_kwargs": {"thinking_mode": thinking_mode}} if thinking_mode else {}
+ )
+
if args.chat:
output_messages = []
@@ -105,6 +135,7 @@ def generate_data(messages, idx, system_prompt):
messages=output_messages,
max_tokens=args.max_tokens,
temperature=args.temperature,
+ extra_body=extra_body,
)
if response.choices[0].finish_reason == "length":
break
@@ -123,7 +154,9 @@ def generate_data(messages, idx, system_prompt):
return
to_write = {"conversation_id": idx}
else:
- to_write = {"conversation_id": idx, "conversations": output_messages}
+ to_write = {"conversation_id": idx, MESSAGES_KEY: output_messages}
+ if thinking_mode:
+ to_write["thinking_mode"] = thinking_mode
with open(args.output_path, "a") as f:
# write in share gpt format
f.write(json.dumps(to_write) + "\n")
@@ -187,7 +220,17 @@ def generate_data(messages, idx, system_prompt):
for idx, sample in enumerate(data):
if idx in finished_ids:
continue
- future = executor.submit(generate_data, sample["conversations"], idx, system_prompt)
+ # Accept both ShareGPT ("conversations") and OAI-chat ("messages") prompt datasets
+ # (e.g. Speculative-Decoding-Dataset-v2 uses "messages"). generate_data already
+ # handles the from/value and role/content message shapes.
+ sample_messages = sample.get("conversations")
+ if sample_messages is None:
+ sample_messages = sample.get("messages")
+ if sample_messages is None:
+ raise KeyError(
+ f"sample {idx} has neither 'conversations' nor 'messages'; keys: {list(sample)}"
+ )
+ future = executor.submit(generate_data, sample_messages, idx, system_prompt)
futures.append(future)
for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
diff --git a/tools/launcher/examples/MiniMax/MiniMax-M3-DFlash/chat_template_train.jinja b/tools/launcher/examples/MiniMax/MiniMax-M3-DFlash/chat_template_train.jinja
new file mode 100644
index 00000000000..8d2e37cae3e
--- /dev/null
+++ b/tools/launcher/examples/MiniMax/MiniMax-M3-DFlash/chat_template_train.jinja
@@ -0,0 +1,256 @@
+{# MiniMax-M3 chat template with {% generation %} tags for answer_only_loss training.
+ Adapted from https://huggingface.co/MiniMaxAI/MiniMax-M3/blob/main/chat_template.jinja
+ with {% generation %} / {% endgeneration %} wrapping the assistant turn's output
+ (think + content + tool_calls), matching the MiniMax-M2.7-DFlash convention: the
+ ']~b]ai\n' header and the trailing eos sit OUTSIDE the generation span, so the loss
+ mask covers only what the model produces. Thinking-mode handling is preserved verbatim
+ so dumps reflect the same enabled/disabled/adaptive mix used during synthesis.
+#}
+{# ---------- special token variables ---------- #}
+{%- set ns_token = ']<]minimax[>[' -%}
+{%- set bod_token = ']~!b[' -%}
+{%- set bos_token = ']~b]' -%}
+{%- set eos_token = '[e~[' -%}
+{%- set toolcall_begin_token = ns_token ~ '' -%}
+{%- set toolcall_end_token = ns_token ~ '' -%}
+{%- set think_begin_token = '' -%}
+{%- set think_end_token = '' -%}
+{%- set image_token = ']<]image[>[' -%}
+{%- set video_token = ']<]video[>[' -%}
+{#- Thinking mode: "enabled" / "disabled" / "adaptive" / not defined -#}
+{#- Recursive XML renderer for tool_call arguments ======================== -#}
+{#- None values are intentionally skipped in mapping iteration so that
+ `null` (which would round-trip to the literal string "null")
+ never appears in the rendered tool_call. The convention is: omit the
+ field entirely. The top-level `_args` loop applies the same rule.
+ The `val is none` branch below is a safety net only — upstream cleaning
+ (drop_none_in_tool_arguments) should ensure no None ever reaches here. -#}
+{%- macro to_xml(val, ns) -%}
+{%- if val is mapping -%}
+{%- for k, v in val.items() if v is not none -%}
+{{ ns }}<{{ k }}>{{ to_xml(v, ns) }}{{ ns }}{{ k }}>
+{%- endfor -%}
+{%- elif val is iterable and val is not string -%}
+{%- for item in val -%}
+{{ ns }}- {{ to_xml(item, ns) }}{{ ns }}
+{%- endfor -%}
+{%- elif val is none -%}
+{#- Should be unreachable when upstream cleaning is applied. -#}
+{%- elif val is boolean -%}
+{{ val | tojson }}
+{%- else -%}
+{{ val }}
+{%- endif -%}
+{%- endmacro -%}
+{#- Tool Rendering Functions ============================================== -#}
+{%- macro render_tool_namespace(namespace_name, tool_list) -%}
+{%- for tool in tool_list -%}
+{{ tool.function | tojson(ensure_ascii=False) }}
+{% endfor -%}
+{%- endmacro -%}
+{%- macro visible_text(content) -%}
+ {%- if content is string -%}
+ {{ content }}
+ {%- elif content is iterable and content is not mapping -%}
+ {%- for item in content -%}
+ {%- if item is mapping and item.type == 'text' -%}
+ {{- item.text }}
+ {%- elif item is mapping and item.type == 'image' -%}
+ {{- image_token }}
+ {%- elif item is mapping and item.type == 'video' -%}
+ {{- video_token}}
+ {%- elif item is string -%}
+ {{- item }}
+ {%- endif -%}
+ {%- endfor -%}
+ {%- elif content is none -%}
+ {{- '' }}
+ {%- else -%}
+ {{- content }}
+ {%- endif -%}
+{%- endmacro -%}
+{#- System Message Construction ============================================ -#}
+{%- macro build_system_message(system_message) -%}
+ {%- if system_message and system_message.content -%}
+ {{- visible_text(system_message.content) }}
+ {%- else -%}
+ {{- 'Your model version is MiniMax-M3, developed by MiniMax. Knowledge cutoff: January 2026. Founded in early 2022, MiniMax is a global AI foundation model company committed to advancing the frontiers of AI towards AGI.' }}
+ {%- endif -%}
+
+ {#- Thinking mode instructions -#}
+ {{- '\n\n\n' }}
+ {{- 'You have a thinking capability that allows you to reason step by step before responding. When thinking is enabled, wrap your reasoning in ' ~ think_begin_token ~ think_end_token ~ ' tags before your response. When thinking is disabled, begin your response directly after the ' ~ think_end_token ~ ' prefix. When thinking is adaptive, decide on your own whether to think for the current turn.\n' }}
+ {%- if thinking_mode is defined -%}
+ {%- if thinking_mode == "enabled" -%}
+ {{- 'Current thinking mode: enabled. You MUST think step by step before every response, including after receiving function/tool results.\n' }}
+ {%- elif thinking_mode == "disabled" -%}
+ {{- 'Current thinking mode: disabled. Do not output any thinking process.\n' }}
+ {%- elif thinking_mode == "adaptive" -%}
+ {{- 'Current thinking mode: adaptive. You are encouraged to think for complex decision-making, multi-step reasoning, or when analyzing function/tool results.\n' }}
+ {%- endif -%}
+ {%- else -%}
+ {{- 'Current thinking mode: adaptive. You are encouraged to think for complex decision-making, multi-step reasoning, or when analyzing function/tool results.\n' }}
+ {%- endif -%}
+ {{- '' }}
+{%- endmacro -%}
+{%- macro build_developer_message(developer_message) -%}
+ {%- if developer_message and developer_message.content -%}
+ {{- visible_text(developer_message.content) }}
+ {%- else -%}
+ {%- if model_identity is not defined -%}
+ {%- set model_identity = "You are a helpful assistant." -%}
+ {%- endif -%}
+ {{- model_identity }}
+ {%- endif -%}
+{%- endmacro -%}
+{#- Main Template Logic ================================================= -#}
+{#- Role mapping: root -> system sp (high priority), system/developer -> developer sp (low priority) -#}
+{%- set system_message = none -%}
+{%- set developer_message = none -%}
+{%- set conversation_messages = messages -%}
+{%- if messages and messages[0].role == "root" -%}
+ {%- set system_message = messages[0] -%}
+ {%- set conversation_messages = messages[1:] -%}
+ {%- if conversation_messages and conversation_messages[0].role in ["system", "developer"] -%}
+ {%- set developer_message = conversation_messages[0] -%}
+ {%- set conversation_messages = conversation_messages[1:] -%}
+ {%- endif -%}
+{%- elif messages and messages[0].role in ["system", "developer"] -%}
+ {%- set developer_message = messages[0] -%}
+ {%- set conversation_messages = messages[1:] -%}
+{%- endif -%}
+{#- Render system sp (higher priority, root role only) -#}
+{{- bod_token ~ bos_token ~ 'system' ~ '\n' }}
+{{- build_system_message(system_message) }}
+{{- eos_token ~ '\n' }}
+
+{#- Render developer sp (lower priority: system/developer role + tools) -#}
+{{- bos_token ~ 'developer' ~ '\n' }}
+{{- build_developer_message(developer_message) }}
+{%- if tools -%}
+ {{- '\n\n' ~ '# Tools' ~ '\n' ~ 'You may call one or more tools to assist with the user query.\nHere are the tools available in JSONSchema format:' ~ '\n' }}
+ {{- '\n' ~ '' ~ '\n' }}
+ {{- render_tool_namespace("functions", tools) }}
+ {{- '' ~ '\n\n' }}
+ {{- 'To call tools, wrap all invocations in a single ' ~ toolcall_begin_token ~ toolcall_end_token ~ ' block. Parameter values containing nested objects or arrays are recursively expanded into XML elements. Example:\n' }}
+ {{- '\n' ~ toolcall_begin_token ~ '\n' }}
+ {{- ns_token + '' }}
+ {{- ns_token + 'value-1' + ns_token + '' }}
+ {{- ns_token + '' }}
+ {{- ns_token + '- ' }}
+ {{- ns_token + 'val-a' + ns_token + '' }}
+ {{- ns_token + 'val-b' + ns_token + '' }}
+ {{- ns_token + '
' }}
+ {{- ns_token + '' }}
+ {{- ns_token + '\n' }}
+ {{- ns_token + '' }}
+ {{- ns_token + 'value-1' + ns_token + '' }}
+ {{- ns_token + '\n' }}
+ {{- toolcall_end_token }}
+{%- endif -%}
+{{- eos_token ~ '\n' }}
+
+{#- Render messages -#}
+{%- set last_tool_call = namespace(name=none) -%}
+{%- for message in conversation_messages -%}
+ {%- if message.role == 'assistant' -%}
+ {{- bos_token ~ 'ai' ~ '\n' }}
+ {%- generation -%}
+ {%- set reasoning_content = '' %}
+ {%- set content = visible_text(message.content) %}
+ {%- if message.reasoning_content is string %}
+ {%- set reasoning_content = message.reasoning_content %}
+ {%- else %}
+ {%- if think_end_token in content %}
+ {%- set reasoning_content = content.split(think_end_token)[0].strip('\n').split(think_begin_token)[-1].strip('\n') %}
+ {%- set content = content.split(think_end_token)[-1].strip('\n') %}
+ {%- endif %}
+ {%- endif %}
+
+ {%- if reasoning_content -%}
+ {#- Render thinking for every assistant turn (all-turn visible) -#}
+ {{- think_begin_token ~ reasoning_content ~ think_end_token }}
+ {%- else -%}
+ {#- No thinking rendered → prefix with think_end_token -#}
+ {{- think_end_token }}
+ {%- endif -%}
+
+ {%- if content -%}
+ {{- content }}
+ {%- endif -%}
+ {%- if message.tool_calls -%}
+ {{- toolcall_begin_token ~ '\n' }}
+
+ {%- for tool_call in message.tool_calls -%}
+ {%- if tool_call.function -%}
+ {%- set tool_call = tool_call.function -%}
+ {%- endif -%}
+{{- ns_token + '' }}
+{%- set _args = tool_call.arguments -%}
+{%- for k, v in _args.items() if v is not none %}
+{{- ns_token + '<' + k + '>' -}}
+{{- to_xml(v, ns_token) -}}
+{{- ns_token + '' + k + '>' }}
+{%- endfor -%}
+{{- ns_token + '' ~ '\n' }}
+ {%- endfor -%}
+
+ {{- toolcall_end_token }}
+ {%- if message.tool_calls[-1].function -%}
+ {%- set last_tool_call.name = message.tool_calls[-1].function.name -%}
+ {%- else -%}
+ {%- set last_tool_call.name = message.tool_calls[-1].name -%}
+ {%- endif -%}
+ {%- else -%}
+ {%- set last_tool_call.name = none -%}
+ {%- endif -%}
+ {%- endgeneration -%}
+ {{- eos_token ~ '\n' }}
+
+ {%- elif message.role == 'tool' -%}
+ {%- if last_tool_call.name is none -%}
+ {{- raise_exception("Message has tool role, but there was no previous assistant message with a tool call!") }}
+ {%- endif -%}
+ {%- if loop.first or (conversation_messages[loop.index0 - 1].role != 'tool') -%}
+ {{- bos_token ~ 'tool' }}
+ {%- endif -%}
+ {{- '\n' }}
+ {%- if message.content is string -%}
+ {{- message.content }}
+ {%- else -%}
+ {%- for tr in message.content -%}
+ {%- if tr is mapping and tr.type is defined and tr.type == 'image' -%}
+ {{- image_token }}
+ {%- elif tr is mapping and tr.type is defined and tr.type == 'video' -%}
+ {{- video_token }}
+ {%- else -%}
+ {{- tr.output if tr.output is defined else (tr.text if tr.type == 'text' and tr.text is defined else tr) }}
+ {%- endif -%}
+ {%- endfor -%}
+ {%- endif -%}
+ {{- '' }}
+ {%- if loop.last or (conversation_messages[loop.index0 + 1].role != 'tool') -%}
+ {{- eos_token ~ '\n' -}}
+ {%- endif -%}
+
+ {%- elif message.role == 'user' -%}
+ {{- bos_token ~ 'user' ~ '\n' }}
+ {{- visible_text(message.content) }}
+ {{- eos_token ~ '\n' }}
+ {%- endif -%}
+{%- endfor -%}
+
+{#- Generation prompt -#}
+{%- if add_generation_prompt -%}
+{{- bos_token ~ 'ai' ~ '\n' }}
+{%- if thinking_mode is defined and thinking_mode == "disabled" -%}
+ {{- think_end_token }}
+{%- elif thinking_mode is defined and thinking_mode == "adaptive" -%}
+ {#- adaptive: no prefix, let model decide -#}
+{%- elif thinking_mode is defined and thinking_mode == "enabled" -%}
+ {#- enabled or not defined: default to think -#}
+ {{- think_begin_token }}
+{%- else -%}
+ {#- adaptive: no prefix, let model decide -#}
+{%- endif -%}
+{%- endif -%}
diff --git a/tools/launcher/examples/MiniMax/MiniMax-M3-DFlash/hf_offline_dflash.yaml b/tools/launcher/examples/MiniMax/MiniMax-M3-DFlash/hf_offline_dflash.yaml
new file mode 100644
index 00000000000..bc848231cc2
--- /dev/null
+++ b/tools/launcher/examples/MiniMax/MiniMax-M3-DFlash/hf_offline_dflash.yaml
@@ -0,0 +1,125 @@
+# DFlash offline speculative decoding training for MiniMax-M3 (427B VL-MoE, 26B active).
+#
+# 2-step pipeline (mirrors MiniMax-M2.7-DFlash/hf_offline_dflash.yaml). Offline is the
+# chosen path for M3 — online FSDP2 training streams the 427B base forward at every step
+# and is too slow at scale:
+# task_0: Dump base-model hidden states once via vLLM extract_hidden_states.
+# task_1: Train the DFlash draft on the dump (FakeBaseModel — loads only lm_head +
+# embed_tokens, not the full 427B base).
+#
+# M3-specific notes (differ from M2.7), all validated 2026-06-22:
+# * Dump serves MiniMax-M3-MXFP8 (NVIDIA-published quant) single-node TP8 on H100. M3
+# is not in stable vLLM yet -> image vllm/vllm-openai:minimax-m3.
+# * --block-size 128 is MANDATORY for M3's MSA sparse attention.
+# * --language-model-only skips the vision encoder (text-only synth/dump).
+# * --enforce-eager + VLLM_RPC_TIMEOUT=1800000 are REQUIRED: M3's MSA Triton kernel
+# (_gqa_sparse_fwd_kernel) JIT-recompiles per input shape; under cudagraph capture a
+# recompile blows the executor RPC timeout (sample_tokens timeout -> EngineDead hang).
+# Eager mode + a long RPC timeout avoids the hang. KV cache stays bf16 (M3's MSA fused
+# kernel rejects fp8 KV).
+# * Training FakeBaseModel reads lm_head + embed_tokens from the bf16 M3 (real weights;
+# these tensors are not what MXFP8 quantizes, so dump@MXFP8 / train@bf16 logits stay
+# consistent). Per Ye Yu: adhere to published bf16/MXFP8 ckpts, do not self-quantize.
+# * Sequence length 8192 (not M2.7's 4096) end-to-end: synth, dump, training — captures
+# full reasoning across the enabled/disabled/adaptive mode mix.
+#
+# Reference: "DFlash: Block Diffusion for Flash Speculative Decoding" (arXiv:2602.06036)
+#
+# Usage:
+# uv run slurm.py --yaml modules/Model-Optimizer/tools/launcher/examples/MiniMax/MiniMax-M3-DFlash/hf_offline_dflash.yaml --yes
+
+job_name: MiniMax-M3-DFlash_offline
+pipeline:
+ global_vars:
+ # bf16 base — used by training's FakeBaseModel (lm_head + embed_tokens) and tokenizer.
+ hf_model: /hf-local/MiniMaxAI/MiniMax-M3
+ # NVIDIA-published MXFP8 quant — used only to serve the dump single-node TP8 on H100.
+ dump_model: /hf-local/MiniMaxAI/MiniMax-M3-MXFP8
+
+ # Step 1: Dump base-model hidden states via vLLM extract_hidden_states (TP=8, MXFP8).
+ task_0:
+ script: common/eagle3/dump_offline_data_vllm.sh
+ args:
+ # Synthetic data from the M3 synth campaign (default.jsonl, 3-way thinking-mode mix),
+ # cleaned + uploaded. Update the suffix once the cleaned set is published.
+ - --input-data /hf-local/modelopt/MiniMax-M3-synthetic-data
+ - --output-dir /scratchspace/dflash_minimax_m3_hidden_states
+ # Must match the draft model's num_hidden_layers (recipe default: 5).
+ - --aux-layers dflash
+ - --answer-only-loss
+ - --chat-template examples/MiniMax/MiniMax-M3-DFlash/chat_template_train.jinja
+ - --max-seq-len 8192
+ - --tp 8
+ # M3 MSA requirements (see header).
+ - --block-size 128
+ - --language-model-only
+ - --enforce-eager
+ environment:
+ - HF_MODEL_CKPT: <>
+ - TRUST_REMOTE_CODE: "1"
+ # Survive MSA Triton-kernel JIT recompiles without an executor RPC timeout.
+ - VLLM_RPC_TIMEOUT: "1800000"
+ slurm_config:
+ _factory_: "slurm_factory"
+ nodes: 1
+ ntasks_per_node: 1
+ gpus_per_node: 8
+ container: vllm/vllm-openai:minimax-m3
+
+ # Step 2: Train DFlash offline on the dumped hidden states. FakeBaseModel avoids loading
+ # the full 427B — only lm_head + embed_tokens are read from the bf16 checkpoint.
+ task_1:
+ script: common/specdec/dflash_online_training.sh
+ args:
+ - --config modules/Model-Optimizer/modelopt_recipes/general/speculative_decoding/dflash.yaml
+ - model.model_name_or_path=<>
+ - model.trust_remote_code=true
+ - model.use_fake_base_for_offline=true
+ - data.mode=offline
+ - data.offline_data_path=/scratchspace/dflash_minimax_m3_hidden_states
+ - data.chat_template=examples/MiniMax/MiniMax-M3-DFlash/chat_template_train.jinja
+ - training.output_dir=/scratchspace/dflash_minimax_m3_offline
+ - training.num_train_epochs=10
+ # bs=1 @ 8192 keeps the activation footprint of M2.7's bs=2 @ 4096 (bs*seqlen equal).
+ - training.per_device_train_batch_size=1
+ - training.learning_rate=1.2e-3
+ - training.warmup_steps=100
+ - training.training_seq_len=8192
+ - training.logging_steps=100
+ - training.save_steps=400
+ - training.disable_tqdm=true
+ - training.dp_shard_size=1
+ - training.answer_only_loss=true
+ - training.ddp_timeout=3600
+ - training.bf16=false
+ - dflash.dflash_self_logit_distillation=true
+ - dflash.dflash_block_size=8
+ - dflash.dflash_num_anchors=512
+ - dflash.dflash_loss_decay_factor=4.0
+ - dflash.dflash_architecture_config.num_hidden_layers=5
+ # Mask token id: in M3, 200054 is a real special token, so the first unused reserved
+ # embedding row is 200061 (M2.7 used 200054).
+ - dflash.dflash_mask_token_id=200061
+ # YaRN rope_scaling injected at EXPORT time only (config.json field; draft weights
+ # unchanged) -> tunable per export. original_max_position_embeddings = training_seq_len
+ # (8192). factor 128 -> 8192*128 = 1048576 = M3's full 1M context. (Use factor 24 ->
+ # 196608 to match M2.7's served target instead.)
+ - dflash.dflash_export_rope_scaling.type=yarn
+ - dflash.dflash_export_rope_scaling.factor=128.0
+ - dflash.dflash_export_rope_scaling.original_max_position_embeddings=8192
+ - dflash.dflash_export_rope_scaling.beta_fast=1.0
+ - dflash.dflash_export_rope_scaling.beta_slow=1.0
+ - dflash.dflash_export_rope_scaling.mscale=1.0
+ - dflash.dflash_export_rope_scaling.mscale_all_dim=1.0
+ environment:
+ - NUM_NODES: "8"
+ # Offline training uses a lightweight FakeBaseModel, so plain DDP suffices (no
+ # ACCELERATE_CONFIG / FSDP2 patches). OVERRIDE_TRANSFORMERS pins 4.52.4 for the
+ # MiniMax-M3 config load.
+ - OVERRIDE_TRANSFORMERS: "4.52.4"
+ - MIXED_PRECISION: "no"
+ slurm_config:
+ _factory_: "slurm_factory"
+ nodes: 8
+ ntasks_per_node: 1
+ gpus_per_node: 8