From 3dba6ec125485573a6b39848c7f840115892ea66 Mon Sep 17 00:00:00 2001 From: yjy415 <2471352175@qq.com> Date: Thu, 18 Jun 2026 16:21:22 +0800 Subject: [PATCH 1/6] add: dmd2 for flux2-klein-base-4B --- diffsynth/diffusion/__init__.py | 1 + diffsynth/diffusion/dmd2.py | 764 ++++++++++++++++++ diffsynth/pipelines/flux2_image.py | 126 ++- .../special/dmd2/FLUX.2-klein-base-4B-DMD2.sh | 29 + examples/flux2/model_training/train.py | 28 +- .../FLUX.2-klein-base-4B-DMD2.py | 20 + 6 files changed, 952 insertions(+), 16 deletions(-) create mode 100644 diffsynth/diffusion/dmd2.py create mode 100644 examples/flux2/model_training/special/dmd2/FLUX.2-klein-base-4B-DMD2.sh create mode 100644 examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B-DMD2.py diff --git a/diffsynth/diffusion/__init__.py b/diffsynth/diffusion/__init__.py index 7823637e2..d285482de 100644 --- a/diffsynth/diffusion/__init__.py +++ b/diffsynth/diffusion/__init__.py @@ -4,3 +4,4 @@ from .runner import launch_training_task, launch_data_process_task from .parsers import * from .loss import * +from .dmd2 import * diff --git a/diffsynth/diffusion/dmd2.py b/diffsynth/diffusion/dmd2.py new file mode 100644 index 000000000..f5d9e2cd5 --- /dev/null +++ b/diffsynth/diffusion/dmd2.py @@ -0,0 +1,764 @@ +import copy +from dataclasses import dataclass +from typing import Optional +import torch +import torch.nn.functional as F +from accelerate import Accelerator +from tqdm import tqdm +from .runner import get_optimizer_class, initialize_deepspeed_gradient_checkpointing + + +def _parse_int_list(value): + if value is None or value == "": + return None + return [int(i) for i in value.split(",") if i != ""] + + +def _parse_float_list(value): + if value is None or value == "": + return None + return [float(i) for i in value.split(",") if i != ""] + + +@dataclass +class DMD2Config: + student_update_freq: int = 5 + student_sample_steps: int = 4 + student_sample_type: str = "sde" + student_schedule: str = "uniform" + student_t_list: Optional[list[float]] = None + matching_t_min: float = 0.001 + matching_t_max: float = 0.999 + matching_t_sampling: str = "uniform" + matching_t_mean: float = 0.0 + matching_t_std: float = 1.0 + gan_loss_weight: float = 0.03 + gan_r1_reg_weight: float = 0.0 + gan_r1_reg_alpha: float = 0.1 + gan_logit_reg_weight: float = 0.0 + fake_score_learning_rate: Optional[float] = None + discriminator_learning_rate: Optional[float] = None + feature_indices: Optional[list[int]] = None + discriminator_hidden_dim: Optional[int] = None + discriminator_num_blocks: Optional[int] = None + teacher_cfg_scale: Optional[float] = None + student_grad_clip_norm: Optional[float] = 10.0 + + @classmethod + def from_args(cls, args): + return cls( + student_update_freq=args.dmd2_student_update_freq, + student_sample_steps=args.dmd2_student_sample_steps, + student_sample_type=args.dmd2_student_sample_type, + student_schedule=args.dmd2_student_schedule, + student_t_list=_parse_float_list(args.dmd2_student_t_list), + matching_t_min=args.dmd2_matching_t_min, + matching_t_max=args.dmd2_matching_t_max, + matching_t_sampling=args.dmd2_matching_t_sampling, + matching_t_mean=args.dmd2_matching_t_mean, + matching_t_std=args.dmd2_matching_t_std, + gan_loss_weight=args.dmd2_gan_loss_weight, + gan_r1_reg_weight=args.dmd2_gan_r1_reg_weight, + gan_r1_reg_alpha=args.dmd2_gan_r1_reg_alpha, + gan_logit_reg_weight=args.dmd2_gan_logit_reg_weight, + fake_score_learning_rate=args.dmd2_fake_score_learning_rate, + discriminator_learning_rate=args.dmd2_discriminator_learning_rate, + feature_indices=_parse_int_list(args.dmd2_feature_indices), + discriminator_hidden_dim=args.dmd2_discriminator_hidden_dim, + discriminator_num_blocks=args.dmd2_discriminator_num_blocks, + teacher_cfg_scale=args.dmd2_teacher_cfg_scale, + student_grad_clip_norm=args.dmd2_student_grad_clip_norm, + ) + + +def add_dmd2_config(parser): + parser.add_argument("--dmd2_student_update_freq", type=int, default=5, help="Update student once every N DMD2 iterations.") + parser.add_argument("--dmd2_student_sample_steps", type=int, default=4, help="Number of distilled student sampling steps.") + parser.add_argument("--dmd2_student_sample_type", type=str, default="sde", choices=["sde", "ode"], help="Student sampling type used by the DMD2 objective.") + parser.add_argument("--dmd2_student_schedule", type=str, default="uniform", choices=["uniform"], help="Student sigma schedule.") + parser.add_argument("--dmd2_student_t_list", type=str, default=None, help="Optional student sigma schedule, including the final 0.") + parser.add_argument("--dmd2_matching_t_min", type=float, default=0.001, help="Minimum matching sigma sampled for DMD2.") + parser.add_argument("--dmd2_matching_t_max", type=float, default=0.999, help="Maximum matching sigma sampled for DMD2.") + parser.add_argument("--dmd2_matching_t_sampling", type=str, default="uniform", choices=["uniform", "logitnormal"], help="Sample matching sigma.") + parser.add_argument("--dmd2_matching_t_mean", type=float, default=0.0, help="Mean for logitnormal matching timestep sampling.") + parser.add_argument("--dmd2_matching_t_std", type=float, default=1.0, help="Std for logitnormal matching timestep sampling.") + parser.add_argument("--dmd2_gan_loss_weight", type=float, default=0.03, help="Generator GAN loss weight.") + parser.add_argument("--dmd2_gan_r1_reg_weight", type=float, default=0.0, help="Approximate R1 regularization weight for the discriminator.") + parser.add_argument("--dmd2_gan_r1_reg_alpha", type=float, default=0.1, help="Noise scale for approximate R1 regularization.") + parser.add_argument("--dmd2_gan_logit_reg_weight", type=float, default=0.0, help="L2 regularization weight on discriminator logits.") + parser.add_argument("--dmd2_fake_score_learning_rate", type=float, default=None, help="Learning rate for the fake score model.") + parser.add_argument("--dmd2_discriminator_learning_rate", type=float, default=None, help="Learning rate for the discriminator.") + parser.add_argument("--dmd2_feature_indices", type=str, default=None, help="DiT block indices used by the discriminator.") + parser.add_argument("--dmd2_discriminator_hidden_dim", type=int, default=None, help="Hidden dimension of DiT features.") + parser.add_argument("--dmd2_discriminator_num_blocks", type=int, default=None, help="Total Flux block count.") + parser.add_argument("--dmd2_teacher_cfg_scale", type=float, default=None, help="CFG scale applied to the teacher x0 in DMD2.") + parser.add_argument("--dmd2_student_grad_clip_norm", type=float, default=10.0, help="Clip student gradients to this norm.") + return parser + + +def _get_optimal_groups(num_channels): + if num_channels <= 32: + groups = max(1, num_channels // 4) + else: + groups = 32 + while groups > 1 and num_channels % groups != 0: + groups -= 1 + assert num_channels % groups == 0, f"{num_channels} not divisible by {groups}" + return groups + + +class FluxDMD2Discriminator(torch.nn.Module): + def __init__(self, feature_indices=None, num_blocks=None, inner_dim=None): + super().__init__() + if num_blocks is None: + raise ValueError("`num_blocks` must be provided.") + if inner_dim is None: + raise ValueError("`inner_dim` must be provided.") + if feature_indices is None: + feature_indices = [int(num_blocks // 2)] + self.feature_indices = sorted({int(i) for i in feature_indices if 0 <= int(i) < num_blocks}) + if len(self.feature_indices) == 0: + raise ValueError("DMD2 discriminator requires at least one valid feature index.") + self.num_features = len(self.feature_indices) + self.inner_dim = inner_dim + + hidden_channels = inner_dim // 2 + self.heads = torch.nn.ModuleList([ + torch.nn.Sequential( + torch.nn.Conv2d(inner_dim, hidden_channels, kernel_size=4, stride=2, padding=1), + torch.nn.GroupNorm(_get_optimal_groups(hidden_channels), hidden_channels), + torch.nn.LeakyReLU(0.2), + torch.nn.Conv2d(hidden_channels, 1, kernel_size=1, stride=1, padding=0), + torch.nn.AdaptiveAvgPool2d((1, 1)), + torch.nn.Flatten(), + ) + for _ in self.feature_indices + ]) + + def forward(self, feats): + if not isinstance(feats, list) or len(feats) != self.num_features: + raise ValueError(f"Expected list of {self.num_features} feature tensors, got {type(feats)} with length {len(feats) if isinstance(feats, list) else 'N/A'}.") + logits = [] + for head, feat in zip(self.heads, feats): + param = next(head.parameters()) + feat = feat.to(device=param.device, dtype=param.dtype) + logits.append(head(feat)) + return torch.cat(logits, dim=1) + + +def _infer_dit_num_blocks(dit, default=40): + if hasattr(dit, "blocks") and hasattr(dit, "single_blocks"): + return len(dit.blocks) + len(dit.single_blocks) + if hasattr(dit, "transformer_blocks") and hasattr(dit, "single_transformer_blocks"): + return len(dit.transformer_blocks) + len(dit.single_transformer_blocks) + return default + + +def _infer_dit_hidden_dim(dit, default=3072): + return int(getattr(dit, "inner_dim", default)) + + +def setup_dmd2_training(module, pipe, config: DMD2Config): + module.dmd2_config = config + module.teacher_dit = copy.deepcopy(pipe.dit).eval().requires_grad_(False) + module.fake_score = copy.deepcopy(pipe.dit).train().requires_grad_(True) + module.discriminator = None + if config.gan_loss_weight > 0: + discriminator_num_blocks = config.discriminator_num_blocks or _infer_dit_num_blocks(pipe.dit) + discriminator_hidden_dim = config.discriminator_hidden_dim or _infer_dit_hidden_dim(pipe.dit) + module.discriminator = FluxDMD2Discriminator( + feature_indices=config.feature_indices, + num_blocks=discriminator_num_blocks, + inner_dim=discriminator_hidden_dim, + ) + module.dmd2_loss = DMD2Loss(config) + module._dmd2_student_param_names = {name for name, param in pipe.dit.named_parameters() if param.requires_grad} + module._dmd2_fake_score_param_names = {name for name, _ in module.fake_score.named_parameters()} + + def input_processor(inputs_shared): + return prepare_dmd2_pipeline_inputs(inputs_shared, config) + + def loss_fn(pipe, inputs_shared, inputs_posi, inputs_nega, iteration=None, **kwargs): + return module.dmd2_loss( + module, + (inputs_shared, inputs_posi, inputs_nega), + 0 if iteration is None else iteration, + ) + + def state_dict_exporter(state_dict, remove_prefix=None): + return export_dmd2_trainable_state_dict(module, state_dict, remove_prefix=remove_prefix) + + module.task_to_input_processor["dmd2"] = input_processor + module.task_to_loss["dmd2"] = loss_fn + module.task_to_state_dict_exporter["dmd2"] = state_dict_exporter + + +def prepare_dmd2_pipeline_inputs(inputs_shared, config: DMD2Config): + if _cfg_enabled(config.teacher_cfg_scale): + inputs_shared["cfg_scale"] = config.teacher_cfg_scale + return inputs_shared + + +def export_dmd2_trainable_state_dict(module, state_dict, remove_prefix=None): + student_names = {"pipe.dit." + name for name in module._dmd2_student_param_names} + state_dict = {name: param for name, param in state_dict.items() if name in student_names} + if remove_prefix is not None: + state_dict = {name[len(remove_prefix):] if name.startswith(remove_prefix) else name: param for name, param in state_dict.items()} + return state_dict + + +def set_dmd2_train_phase(module, student_phase: bool): + module.pipe.dit.train(student_phase) + for name, param in module.pipe.dit.named_parameters(): + param.requires_grad = student_phase and name in module._dmd2_student_param_names + + module.fake_score.train(not student_phase) + for name, param in module.fake_score.named_parameters(): + param.requires_grad = (not student_phase) and name in module._dmd2_fake_score_param_names + + if module.discriminator is not None: + module.discriminator.train(not student_phase) + module.discriminator.requires_grad_(not student_phase) + + +def _expand_like(value, target, dtype=None): + if dtype is None: + dtype = target.dtype + if not isinstance(value, torch.Tensor): + value = torch.tensor(value, device=target.device, dtype=dtype) + value = value.to(device=target.device, dtype=dtype) + while value.ndim < target.ndim: + value = value.view(*value.shape, 1) + return value + + +def _flow_to_x0(latents, flow, sigma): + original_dtype = latents.dtype + latents = latents.to(torch.float64) + flow = flow.to(torch.float64) + sigma = _expand_like(sigma, latents) + return (latents - sigma * flow).to(original_dtype) + + +def _forward_process(x0, eps, sigma): + original_dtype = x0.dtype + x0 = x0.to(torch.float64) + eps = eps.to(torch.float64) + sigma = _expand_like(sigma, x0) + return ((1 - sigma) * x0 + sigma * eps).to(original_dtype) + + +def _scale_noise(noise, sigma): + original_dtype = noise.dtype + noise = noise.to(torch.float64) + sigma = _expand_like(sigma, noise) + return (noise * sigma).to(original_dtype) + + +def make_dmd2_student_schedule( + pipe, + num_steps, + device, + student_schedule="uniform", + student_t_list=None, + matching_t_max=0.999, +): + time_dtype = torch.float64 + if student_t_list is not None: + sigmas = torch.tensor(student_t_list, device=device, dtype=time_dtype) + if sigmas[-1].item() != 0: + raise ValueError("`dmd2_student_t_list` must include a final 0.") + if len(sigmas) != num_steps + 1: + raise ValueError("The student sigma schedule length must equal `student_sample_steps + 1`.") + timesteps = sigmas * pipe.scheduler.num_train_timesteps + return sigmas, timesteps + if student_schedule == "uniform": + sigma_start = min(float(matching_t_max), 0.999) + sigmas = torch.linspace(sigma_start, 0.0, num_steps + 1, device=device, dtype=time_dtype) + timesteps = sigmas * pipe.scheduler.num_train_timesteps + return sigmas, timesteps + raise ValueError(f"Unsupported DMD2 student schedule: {student_schedule}") + + +def _variational_score_distillation_loss(gen_data, teacher_x0, fake_score_x0): + dims = tuple(range(1, teacher_x0.ndim)) + with torch.no_grad(): + weight = 1 / ((gen_data.float() - teacher_x0.float()).abs().mean(dim=dims, keepdim=True) + 1e-6) + weight = weight.to(dtype=gen_data.dtype) + pseudo_target = gen_data - (fake_score_x0 - teacher_x0) * weight + loss = 0.5 * F.mse_loss(gen_data.float(), pseudo_target.float(), reduction="mean") + return loss + + +def _mean_abs_by_sample(value): + dims = tuple(range(1, value.ndim)) + return value.detach().float().abs().mean(dim=dims).mean() + + +def _gan_loss_generator(fake_logits): + assert fake_logits.ndim == 2, f"fake_logits has shape {fake_logits.shape}" + gan_loss = F.softplus(-fake_logits).mean() + return gan_loss + + +def _gan_loss_discriminator(real_logits, fake_logits): + assert fake_logits.ndim == 2, f"fake_logits has shape {fake_logits.shape}" + assert real_logits.ndim == 2, f"real_logits has shape {real_logits.shape}" + gan_loss = F.softplus(fake_logits).mean() + F.softplus(-real_logits).mean() + return gan_loss + + +def _cfg_enabled(cfg_scale): + return cfg_scale is not None and abs(float(cfg_scale) - 1.0) > 1e-12 + + +class DMD2Loss: + def __init__(self, config: DMD2Config): + self.config = config + self._last_teacher_cfg_delta = None + + def _model_forward_x0( + self, + module, + dit, + latents, + timestep, + sigma, + inputs_shared, + inputs_posi, + ): + pipe = module.pipe + models = {name: getattr(pipe, name) for name in pipe.in_iteration_models} + models["dit"] = dit + shared = dict(inputs_shared) + posi = dict(inputs_posi) + shared["latents"] = latents + flow = pipe.model_fn( + **models, + **shared, + **posi, + timestep=timestep, + progress_id=0, + num_inference_steps=1, + ) + return _flow_to_x0(latents, flow, sigma) + + def _model_forward_features( + self, + module, + dit, + latents, + timestep, + inputs_shared, + inputs_posi, + ): + if module.discriminator is None: + raise ValueError("DMD2 feature extraction requires a discriminator.") + pipe = module.pipe + models = {name: getattr(pipe, name) for name in pipe.in_iteration_models} + models["dit"] = dit + shared = dict(inputs_shared) + posi = dict(inputs_posi) + shared["latents"] = latents + return pipe.model_fn( + **models, + **shared, + **posi, + timestep=timestep, + progress_id=0, + num_inference_steps=1, + feature_indices=set(module.discriminator.feature_indices), + return_features=True, + ) + + def _teacher_forward_x0( + self, + module, + latents, + timestep, + sigma, + inputs_shared, + inputs_posi, + inputs_nega, + ): + teacher_x0_pos = self._model_forward_x0( + module, + module.teacher_dit, + latents, + timestep, + sigma, + inputs_shared, + inputs_posi, + ) + if not _cfg_enabled(self.config.teacher_cfg_scale): + self._last_teacher_cfg_delta = torch.zeros((), device=latents.device, dtype=torch.float32) + return teacher_x0_pos + + x0_neg = self._model_forward_x0( + module, + module.teacher_dit, + latents, + timestep, + sigma, + inputs_shared, + inputs_nega, + ) + x0 = x0_neg + float(self.config.teacher_cfg_scale) * (teacher_x0_pos - x0_neg) + self._last_teacher_cfg_delta = _mean_abs_by_sample(x0 - teacher_x0_pos) + return x0 + + def _teacher_forward_features( + self, + module, + latents, + timestep, + inputs_shared, + inputs_posi, + ): + return self._model_forward_features( + module, + module.teacher_dit, + latents, + timestep, + inputs_shared, + inputs_posi, + ) + + def _sample_matching_timestep(self, pipe, device, dtype, batch_size=1, inputs_shared=None, real_data=None): + if self.config.matching_t_min > self.config.matching_t_max: + raise ValueError("`dmd2_matching_t_min` must be <= `dmd2_matching_t_max`.") + time_dtype = torch.float64 + if self.config.matching_t_sampling == "uniform": + sigma = torch.rand(batch_size, device=device, dtype=time_dtype) + sigma = sigma * (self.config.matching_t_max - self.config.matching_t_min) + self.config.matching_t_min + timestep = sigma * pipe.scheduler.num_train_timesteps + return timestep, sigma + if self.config.matching_t_sampling == "logitnormal": + sigma = torch.sigmoid( + torch.randn(batch_size, device=device, dtype=time_dtype) * self.config.matching_t_std + + self.config.matching_t_mean + ) + sigma = sigma * (self.config.matching_t_max - self.config.matching_t_min) + self.config.matching_t_min + sigma = sigma.clamp(self.config.matching_t_min, self.config.matching_t_max) + timestep = sigma * pipe.scheduler.num_train_timesteps + return timestep, sigma + raise ValueError(f"Unsupported DMD2 matching timestep sampling: {self.config.matching_t_sampling}") + + def _generate_student_data(self, module, real_data, inputs_shared, inputs_posi): + pipe = module.pipe + device, dtype = real_data.device, real_data.dtype + batch_size = real_data.shape[0] + student_sigmas, student_timesteps = make_dmd2_student_schedule( + pipe, + self.config.student_sample_steps, + device, + student_schedule=self.config.student_schedule, + student_t_list=self.config.student_t_list, + matching_t_max=self.config.matching_t_max, + ) + if len(student_sigmas) != self.config.student_sample_steps + 1: + raise ValueError("The student sigma schedule length must equal `student_sample_steps + 1`.") + + if self.config.student_sample_steps == 1: + timestep = student_timesteps[0:1].expand(batch_size) + sigma = student_sigmas[0:1].expand(batch_size) + input_student = _scale_noise(torch.randn_like(real_data), sigma) + else: + step_id = torch.randint(0, self.config.student_sample_steps, (batch_size,), device=device) + sigma = student_sigmas[step_id] + timestep = student_timesteps[step_id] + eps_student = torch.randn_like(real_data) + input_student = _forward_process(real_data, eps_student, sigma) + gen_data = self._model_forward_x0( + module, + module.pipe.dit, + input_student, + timestep, + sigma, + inputs_shared, + inputs_posi, + ) + return gen_data, input_student + + def _compute_real_feat(self, module, real_data, timestep, sigma, eps, inputs_shared, inputs_posi): + real_timestep, real_sigma, real_eps = timestep, sigma, eps + perturbed_real = _forward_process(real_data, real_eps, real_sigma) + real_feat = self._model_forward_features( + module, + module.teacher_dit, + perturbed_real, + real_timestep, + inputs_shared, + inputs_posi, + ) + return real_feat, real_timestep, real_sigma + + def _student_update_step(self, module, real_data, inputs_shared, inputs_posi, inputs_nega): + gen_data, input_student = self._generate_student_data(module, real_data, inputs_shared, inputs_posi) + timestep, sigma = self._sample_matching_timestep( + module.pipe, + real_data.device, + real_data.dtype, + real_data.shape[0], + inputs_shared=inputs_shared, + real_data=real_data, + ) + eps = torch.randn_like(real_data) + perturbed_data = _forward_process(gen_data, eps, sigma) + + with torch.no_grad(): + fake_score_x0 = self._model_forward_x0( + module, + module.fake_score, + perturbed_data, + timestep, + sigma, + inputs_shared, + inputs_posi, + ) + + if self.config.gan_loss_weight > 0: + with torch.no_grad(): + teacher_x0 = self._teacher_forward_x0( + module, perturbed_data, timestep, sigma, inputs_shared, inputs_posi, inputs_nega + ) + fake_feat = self._teacher_forward_features( + module, + perturbed_data, + timestep, + inputs_shared, + inputs_posi, + ) + fake_logits_gen = module.discriminator(fake_feat) + gan_loss_gen = _gan_loss_generator(fake_logits_gen) + else: + with torch.no_grad(): + teacher_x0 = self._teacher_forward_x0( + module, perturbed_data, timestep, sigma, inputs_shared, inputs_posi, inputs_nega + ) + gan_loss_gen = torch.zeros((), device=real_data.device, dtype=torch.float32) + + vsd_loss = _variational_score_distillation_loss(gen_data, teacher_x0.detach(), fake_score_x0) + gan_loss_weighted = self.config.gan_loss_weight * gan_loss_gen + loss = vsd_loss + gan_loss_weighted + + with torch.no_grad(): + teacher_delta = _mean_abs_by_sample(gen_data - teacher_x0) + fake_delta = _mean_abs_by_sample(gen_data - fake_score_x0) + vsd_delta = _mean_abs_by_sample(fake_score_x0 - teacher_x0) + effective_gan_weight = gan_loss_weighted.detach() / gan_loss_gen.detach().clamp_min(1e-12) + + return { + "total_loss": loss, + "vsd_loss": vsd_loss.detach(), + "gan_loss_gen": gan_loss_gen.detach(), + "gan_loss_gen_weighted": gan_loss_weighted.detach(), + "gan_loss_effective_weight": effective_gan_weight, + "dmd2_teacher_delta": teacher_delta, + "dmd2_teacher_cfg_delta": self._last_teacher_cfg_delta.detach(), + "dmd2_fake_score_delta": fake_delta, + "dmd2_vsd_delta": vsd_delta, + "dmd2_sigma_mean": sigma.detach().float().mean(), + "student_input_mean": input_student.detach().float().mean(), + } + + def _fake_score_discriminator_update_step(self, module, real_data, inputs_shared, inputs_posi, inputs_nega): + with torch.no_grad(): + gen_data, _ = self._generate_student_data(module, real_data, inputs_shared, inputs_posi) + timestep, sigma = self._sample_matching_timestep( + module.pipe, + real_data.device, + real_data.dtype, + real_data.shape[0], + inputs_shared=inputs_shared, + real_data=real_data, + ) + eps = torch.randn_like(real_data) + x_t_sg = _forward_process(gen_data, eps, sigma) + + fake_score_x0 = self._model_forward_x0( + module, + module.fake_score, + x_t_sg, + timestep, + sigma, + inputs_shared, + inputs_posi, + ) + fake_score_loss = F.mse_loss(fake_score_x0.float(), gen_data.float(), reduction="mean") + with torch.no_grad(): + fake_score_delta = _mean_abs_by_sample(fake_score_x0 - gen_data) + + gan_loss_disc = torch.zeros_like(fake_score_loss) + gan_loss_ar1 = torch.zeros_like(fake_score_loss) + gan_loss_logit_reg = torch.zeros_like(fake_score_loss) + real_logit_mean = torch.zeros_like(fake_score_loss) + fake_logit_mean = torch.zeros_like(fake_score_loss) + if self.config.gan_loss_weight > 0: + with torch.no_grad(): + fake_feat = self._model_forward_features( + module, + module.teacher_dit, + x_t_sg, + timestep, + inputs_shared, + inputs_posi, + ) + real_feat, real_timestep, real_sigma = self._compute_real_feat( + module, real_data, timestep, sigma, eps, inputs_shared, inputs_posi + ) + real_logits = module.discriminator(real_feat) + fake_logits = module.discriminator(fake_feat) + real_logit_mean = real_logits.detach().float().mean() + fake_logit_mean = fake_logits.detach().float().mean() + gan_loss_disc = _gan_loss_discriminator(real_logits, fake_logits) + if self.config.gan_logit_reg_weight > 0: + gan_loss_logit_reg = 0.5 * (real_logits.float().square().mean() + fake_logits.float().square().mean()) + if self.config.gan_r1_reg_weight > 0: + perturbed_real_alpha = real_data + self.config.gan_r1_reg_alpha * torch.randn_like(real_data) + with torch.no_grad(): + real_feat_alpha = self._model_forward_features( + module, + module.teacher_dit, + perturbed_real_alpha, + real_timestep, + inputs_shared, + inputs_posi, + ) + real_logits_alpha = module.discriminator(real_feat_alpha) + gan_loss_ar1 = F.mse_loss(real_logits, real_logits_alpha, reduction="mean") + + loss = ( + fake_score_loss + + gan_loss_disc + + self.config.gan_r1_reg_weight * gan_loss_ar1 + + self.config.gan_logit_reg_weight * gan_loss_logit_reg + ) + return { + "total_loss": loss, + "fake_score_loss": fake_score_loss.detach(), + "fake_score_delta": fake_score_delta, + "gan_loss_disc": gan_loss_disc.detach(), + "gan_loss_ar1": gan_loss_ar1.detach(), + "gan_loss_logit_reg": gan_loss_logit_reg.detach(), + "gan_loss_logit_reg_weighted": (self.config.gan_logit_reg_weight * gan_loss_logit_reg).detach(), + "gan_real_logit": real_logit_mean, + "gan_fake_logit": fake_logit_mean, + "dmd2_sigma_mean": sigma.detach().float().mean(), + } + + def __call__(self, module, inputs, iteration): + inputs_shared, inputs_posi, inputs_nega = inputs + real_data = inputs_shared.get("input_latents") + if real_data is None: + raise ValueError("DMD2 requires image latents from the dataset. Please provide training images.") + student_phase = iteration % self.config.student_update_freq == 0 + set_dmd2_train_phase(module, student_phase) + if student_phase: + return self._student_update_step(module, real_data, inputs_shared, inputs_posi, inputs_nega) + return self._fake_score_discriminator_update_step(module, real_data, inputs_shared, inputs_posi, inputs_nega) + + +def _trainable_params(module): + return [param for param in module.parameters() if param.requires_grad] + + +def _dmd2_current_optimizers(config, optimizers, iteration): + if iteration % config.student_update_freq == 0: + return [optimizers["student"]], [optimizers["student_scheduler"]] + current_optimizers = [optimizers["fake_score"]] + current_schedulers = [optimizers["fake_score_scheduler"]] + if "discriminator" in optimizers: + current_optimizers.append(optimizers["discriminator"]) + current_schedulers.append(optimizers["discriminator_scheduler"]) + return current_optimizers, current_schedulers + + +def launch_dmd2_training_task( + accelerator: Accelerator, + dataset: torch.utils.data.Dataset, + model, + model_logger, + learning_rate: float = 1e-5, + weight_decay: float = 1e-2, + num_workers: int = 1, + save_steps: int = None, + num_epochs: int = 1, + customized_optimizer: str = None, + args=None, + **kwargs, +): + if args is not None: + learning_rate = args.learning_rate + weight_decay = args.weight_decay + num_workers = args.dataset_num_workers + save_steps = args.save_steps + num_epochs = args.num_epochs + customized_optimizer = args.customized_optimizer + + optimizer_class = get_optimizer_class(customized_optimizer) + config = model.dmd2_config + fake_score_lr = config.fake_score_learning_rate or learning_rate + discriminator_lr = config.discriminator_learning_rate or learning_rate + + student_optimizer = optimizer_class(_trainable_params(model.pipe.dit), lr=learning_rate, weight_decay=weight_decay) + fake_score_optimizer = optimizer_class(model.fake_score.parameters(), lr=fake_score_lr, weight_decay=weight_decay) + student_scheduler = torch.optim.lr_scheduler.ConstantLR(student_optimizer, factor=1.0, total_iters=1) + fake_score_scheduler = torch.optim.lr_scheduler.ConstantLR(fake_score_optimizer, factor=1.0, total_iters=1) + + optimizers = { + "student": student_optimizer, + "fake_score": fake_score_optimizer, + "student_scheduler": student_scheduler, + "fake_score_scheduler": fake_score_scheduler, + } + prepare_items = [model, student_optimizer, fake_score_optimizer, student_scheduler, fake_score_scheduler] + if model.discriminator is not None: + discriminator_optimizer = optimizer_class(model.discriminator.parameters(), lr=discriminator_lr, weight_decay=weight_decay) + discriminator_scheduler = torch.optim.lr_scheduler.ConstantLR(discriminator_optimizer, factor=1.0, total_iters=1) + optimizers["discriminator"] = discriminator_optimizer + optimizers["discriminator_scheduler"] = discriminator_scheduler + prepare_items.extend([discriminator_optimizer, discriminator_scheduler]) + + dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers) + prepare_items.append(dataloader) + model.to(device=accelerator.device) + prepared = accelerator.prepare(*prepare_items) + + model = prepared[0] + prepared_tail = list(prepared[1:]) + optimizers["student"] = prepared_tail.pop(0) + optimizers["fake_score"] = prepared_tail.pop(0) + optimizers["student_scheduler"] = prepared_tail.pop(0) + optimizers["fake_score_scheduler"] = prepared_tail.pop(0) + if model.discriminator is not None: + optimizers["discriminator"] = prepared_tail.pop(0) + optimizers["discriminator_scheduler"] = prepared_tail.pop(0) + dataloader = prepared_tail.pop(0) + + initialize_deepspeed_gradient_checkpointing(accelerator) + iteration = 0 + for epoch_id in range(num_epochs): + for data in tqdm(dataloader): + with accelerator.accumulate(model): + if dataset.load_from_cache: + loss_map = model({}, inputs=data, iteration=iteration) + else: + loss_map = model(data, iteration=iteration) + loss = loss_map["total_loss"] + current_optimizers, current_schedulers = _dmd2_current_optimizers(config, optimizers, iteration) + accelerator.backward(loss) + if iteration % config.student_update_freq == 0 and config.student_grad_clip_norm is not None and config.student_grad_clip_norm > 0: + accelerator.clip_grad_norm_(_trainable_params(model.pipe.dit), config.student_grad_clip_norm) + for optimizer in current_optimizers: + optimizer.step() + for scheduler in current_schedulers: + scheduler.step() + for optimizer in current_optimizers: + optimizer.zero_grad(set_to_none=True) + model_logger.on_step_end(accelerator, model, save_steps, loss=loss, metrics=loss_map) + iteration += 1 + if save_steps is None: + model_logger.on_epoch_end(accelerator, model, epoch_id) + + model_logger.on_training_end(accelerator, model, save_steps) \ No newline at end of file diff --git a/diffsynth/pipelines/flux2_image.py b/diffsynth/pipelines/flux2_image.py index 7807ff970..26f028955 100644 --- a/diffsynth/pipelines/flux2_image.py +++ b/diffsynth/pipelines/flux2_image.py @@ -616,29 +616,129 @@ def model_fn_flux2( extra_text_embedding=None, use_gradient_checkpointing=False, use_gradient_checkpointing_offload=False, + feature_indices=None, + return_features=False, **kwargs, ): + feature_indices = set() if feature_indices is None else set(feature_indices) image_seq_len = latents.shape[1] if edit_latents is not None: image_seq_len = latents.shape[1] latents = torch.concat([latents, edit_latents], dim=1) image_ids = torch.concat([image_ids, edit_image_ids], dim=1) - embedded_guidance = torch.tensor([embedded_guidance], device=latents.device) + if embedded_guidance is None: + embedded_guidance = None + elif isinstance(embedded_guidance, torch.Tensor): + embedded_guidance = embedded_guidance.to(device=latents.device, dtype=latents.dtype).flatten() + if embedded_guidance.numel() == 1: + embedded_guidance = embedded_guidance.expand(latents.shape[0]) + elif embedded_guidance.numel() != latents.shape[0]: + raise ValueError("`embedded_guidance` must be a scalar or match the latent batch size.") + else: + embedded_guidance = torch.full((latents.shape[0],), float(embedded_guidance), device=latents.device, dtype=latents.dtype) if extra_text_embedding is not None: extra_text_ids = torch.zeros((1, extra_text_embedding.shape[1], 4), dtype=text_ids.dtype, device=text_ids.device) extra_text_ids[:, :, -1] = torch.arange(prompt_embeds.shape[1], prompt_embeds.shape[1] + extra_text_embedding.shape[1]) prompt_embeds = torch.concat([prompt_embeds, extra_text_embedding], dim=1) text_ids = torch.concat([text_ids, extra_text_ids], dim=1) - model_output = dit( - hidden_states=latents, - timestep=timestep / 1000, - guidance=embedded_guidance, - encoder_hidden_states=prompt_embeds, - txt_ids=text_ids, - img_ids=image_ids, - kv_cache=kv_cache, - use_gradient_checkpointing=use_gradient_checkpointing, - use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + if not return_features: + model_output = dit( + hidden_states=latents, + timestep=timestep / 1000, + guidance=embedded_guidance, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=image_ids, + kv_cache=kv_cache, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + ) + model_output = model_output[:, :image_seq_len] + return model_output + + height, width = kwargs.get("height"), kwargs.get("width") + if height is not None and width is not None: + feature_height, feature_width = int(height) // 16, int(width) // 16 + else: + feature_height = int(math.sqrt(image_seq_len)) + feature_width = image_seq_len // feature_height if feature_height > 0 else 0 + if feature_height * feature_width != image_seq_len: + raise ValueError("Flux2 feature extraction requires height/width or square latent tokens.") + + features = [] + + def append_feature(feat): + feat = feat[:, :image_seq_len] + batch_size, _, channels = feat.shape + feat = feat.permute(0, 2, 1).reshape(batch_size, channels, feature_height, feature_width) + features.append(feat) + if len(features) == len(feature_indices): + return features + return None + + num_txt_tokens = prompt_embeds.shape[1] + timestep = timestep.to(latents.dtype) + guidance = None if embedded_guidance is None else embedded_guidance.to(latents.dtype) * 1000 + temb = dit.time_guidance_embed(timestep, guidance) + + double_stream_mod_img = dit.double_stream_modulation_img(temb) + double_stream_mod_txt = dit.double_stream_modulation_txt(temb) + single_stream_mod = dit.single_stream_modulation(temb)[0] + + hidden_states = dit.x_embedder(latents) + encoder_hidden_states = dit.context_embedder(prompt_embeds) + + if image_ids.ndim == 3: + image_ids = image_ids[0] + if text_ids.ndim == 3: + text_ids = text_ids[0] + + image_rotary_emb = dit.pos_embed(image_ids) + text_rotary_emb = dit.pos_embed(text_ids) + concat_rotary_emb = ( + torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0), + torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0), ) - model_output = model_output[:, :image_seq_len] - return model_output + + for block_id, block in enumerate(dit.transformer_blocks): + encoder_hidden_states, hidden_states = gradient_checkpoint_forward( + block, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb_mod_params_img=double_stream_mod_img, + temb_mod_params_txt=double_stream_mod_txt, + image_rotary_emb=concat_rotary_emb, + joint_attention_kwargs=None, + kv_cache=None if kv_cache is None else kv_cache.get(f"double_{block_id}"), + ) + if block_id in feature_indices: + selected_features = append_feature(hidden_states) + if selected_features is not None: + return selected_features + + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + num_double_blocks = len(dit.transformer_blocks) + + for block_id, block in enumerate(dit.single_transformer_blocks): + hidden_states = gradient_checkpoint_forward( + block, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + hidden_states=hidden_states, + encoder_hidden_states=None, + temb_mod_params=single_stream_mod, + image_rotary_emb=concat_rotary_emb, + joint_attention_kwargs=None, + kv_cache=None if kv_cache is None else kv_cache.get(f"single_{block_id}"), + ) + feature_id = block_id + num_double_blocks + if feature_id in feature_indices: + selected_features = append_feature(hidden_states[:, num_txt_tokens:num_txt_tokens + image_seq_len]) + if selected_features is not None: + return selected_features + + if len(features) != len(feature_indices): + raise ValueError(f"Only collected {len(features)} feature maps for {len(feature_indices)} requested feature indices.") + return features diff --git a/examples/flux2/model_training/special/dmd2/FLUX.2-klein-base-4B-DMD2.sh b/examples/flux2/model_training/special/dmd2/FLUX.2-klein-base-4B-DMD2.sh new file mode 100644 index 000000000..c74d329a7 --- /dev/null +++ b/examples/flux2/model_training/special/dmd2/FLUX.2-klein-base-4B-DMD2.sh @@ -0,0 +1,29 @@ +accelerate launch --num_processes 1 --num_machines 1 --mixed_precision bf16 --dynamo_backend no examples/flux2/model_training/train.py \ + --dataset_base_path /mnt/nas1/yejinyan.yjy/diffsynth-study/DiffSynth-Studio/data/OmniData/midjourney-v6-520k-raw-diffsynth \ + --dataset_metadata_path /mnt/nas1/yejinyan.yjy/diffsynth-study/DiffSynth-Studio/data/OmniData/midjourney-v6-520k-raw-diffsynth/metadata.csv \ + --height 512 \ + --width 512 \ + --dataset_repeat 1 \ + --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \ + --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \ + --learning_rate 1e-5 \ + --dmd2_fake_score_learning_rate 1e-5 \ + --dmd2_discriminator_learning_rate 1e-5 \ + --num_epochs 1 \ + --save_steps 1000 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/FLUX.2-klein-base-4B_dmd2-4steps" \ + --trainable_models "dit" \ + --task dmd2 \ + --dmd2_student_sample_steps 4 \ + --dmd2_student_sample_type sde \ + --dmd2_student_schedule uniform \ + --dmd2_student_update_freq 5 \ + --dmd2_matching_t_sampling uniform \ + --dmd2_gan_loss_weight 0.03 \ + --dmd2_matching_t_min 0.001 \ + --dmd2_matching_t_max 0.999 \ + --dmd2_feature_indices 12 \ + --embedded_guidance 4 \ + --dmd2_teacher_cfg_scale 4 \ + --use_gradient_checkpointing \ No newline at end of file diff --git a/examples/flux2/model_training/train.py b/examples/flux2/model_training/train.py index 8e32eefc4..69baba46d 100644 --- a/examples/flux2/model_training/train.py +++ b/examples/flux2/model_training/train.py @@ -18,9 +18,11 @@ def __init__( extra_inputs=None, fp8_models=None, offload_models=None, + embedded_guidance=1.0, template_model_id_or_path=None, resume_from_checkpoint=None, remove_prefix_in_ckpt=None, enable_lora_hot_loading=False, + dmd2_config=None, device="cpu", task="sft", ): @@ -47,7 +49,10 @@ def __init__( self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else [] self.fp8_models = fp8_models + self.embedded_guidance = embedded_guidance self.task = task + self.dmd2_config = None + self.task_to_input_processor = {} self.task_to_loss = { "sft:data_process": lambda pipe, *args: args, "direct_distill:data_process": lambda pipe, *args: args, @@ -56,6 +61,9 @@ def __init__( "direct_distill": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi), "direct_distill:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi), } + self.task_to_state_dict_exporter = {} + if task == "dmd2": + setup_dmd2_training(self, self.pipe, dmd2_config) def get_pipeline_inputs(self, data): inputs_posi = {"prompt": data["prompt"]} @@ -68,29 +76,38 @@ def get_pipeline_inputs(self, data): "width": data["image"].size[0], # Please do not modify the following parameters # unless you clearly know what this will cause. - "embedded_guidance": 1.0, + "embedded_guidance": self.embedded_guidance, "cfg_scale": 1, "rand_device": self.pipe.device, "use_gradient_checkpointing": self.use_gradient_checkpointing, "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload, } + if self.task in self.task_to_input_processor: + inputs_shared = self.task_to_input_processor[self.task](inputs_shared) inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared) return inputs_shared, inputs_posi, inputs_nega - def forward(self, data, inputs=None): + def forward(self, data, inputs=None, iteration=None): if inputs is None: inputs = self.get_pipeline_inputs(data) inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype) for unit in self.pipe.units: inputs = self.pipe.unit_runner(unit, self.pipe, *inputs) - loss = self.task_to_loss[self.task](self.pipe, *inputs) + loss = self.task_to_loss[self.task](self.pipe, *inputs, iteration=iteration) return loss + def export_trainable_state_dict(self, state_dict, remove_prefix=None): + if self.task in self.task_to_state_dict_exporter: + return self.task_to_state_dict_exporter[self.task](state_dict, remove_prefix=remove_prefix) + return super().export_trainable_state_dict(state_dict, remove_prefix=remove_prefix) + def flux2_parser(): parser = argparse.ArgumentParser(description="Simple example of a training script.") parser = add_general_config(parser) parser = add_image_size_config(parser) + parser = add_dmd2_config(parser) parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.") + parser.add_argument("--embedded_guidance", type=float, default=1.0, help="Flux.2 embedded guidance value.") parser.add_argument("--initialize_model_on_cpu", default=False, action="store_true", help="Whether to initialize models on CPU.") return parser @@ -98,6 +115,8 @@ def flux2_parser(): if __name__ == "__main__": parser = flux2_parser() args = parser.parse_args() + if args.task == "dmd2": + args.find_unused_parameters = True accelerator = accelerate.Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, @@ -133,10 +152,12 @@ def flux2_parser(): extra_inputs=args.extra_inputs, fp8_models=args.fp8_models, offload_models=args.offload_models, + embedded_guidance=args.embedded_guidance, template_model_id_or_path=args.template_model_id_or_path, resume_from_checkpoint=args.resume_from_checkpoint, remove_prefix_in_ckpt=args.remove_prefix_in_ckpt, enable_lora_hot_loading=args.enable_lora_hot_loading, + dmd2_config=DMD2Config.from_args(args), task=args.task, device="cpu" if (args.initialize_model_on_cpu or args.enable_model_cpu_offload) else accelerator.device, ) @@ -156,5 +177,6 @@ def flux2_parser(): "sft:train": launch_training_task, "direct_distill": launch_training_task, "direct_distill:train": launch_training_task, + "dmd2": launch_dmd2_training_task, } launcher_map[args.task](accelerator, dataset, model, model_logger, args=args) diff --git a/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B-DMD2.py b/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B-DMD2.py new file mode 100644 index 000000000..ead3cc881 --- /dev/null +++ b/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B-DMD2.py @@ -0,0 +1,20 @@ +from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig +from diffsynth.core import load_state_dict +import torch + +pipe = Flux2ImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda:2", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"), +) +state_dict = load_state_dict("./models/train/FLUX.2-klein-base-4B_dmd2-4steps/step-200.safetensors", torch_dtype=torch.bfloat16) +pipe.dit.load_state_dict(state_dict) + +prompt = "a dog" +image = pipe(prompt=prompt, seed=0, num_inference_steps=4, height=512, width=512) +image.save("image_FLUX.2-klein-base-4B-DMD2-4steps.jpg") \ No newline at end of file From 04559ba307b19572aee467de8761e96c65b67d5e Mon Sep 17 00:00:00 2001 From: yjy415 <2471352175@qq.com> Date: Thu, 18 Jun 2026 16:26:22 +0800 Subject: [PATCH 2/6] fix: examples --- .../special/dmd2/FLUX.2-klein-base-4B-DMD2.sh | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/examples/flux2/model_training/special/dmd2/FLUX.2-klein-base-4B-DMD2.sh b/examples/flux2/model_training/special/dmd2/FLUX.2-klein-base-4B-DMD2.sh index c74d329a7..49e2b190e 100644 --- a/examples/flux2/model_training/special/dmd2/FLUX.2-klein-base-4B-DMD2.sh +++ b/examples/flux2/model_training/special/dmd2/FLUX.2-klein-base-4B-DMD2.sh @@ -1,15 +1,17 @@ -accelerate launch --num_processes 1 --num_machines 1 --mixed_precision bf16 --dynamo_backend no examples/flux2/model_training/train.py \ - --dataset_base_path /mnt/nas1/yejinyan.yjy/diffsynth-study/DiffSynth-Studio/data/OmniData/midjourney-v6-520k-raw-diffsynth \ - --dataset_metadata_path /mnt/nas1/yejinyan.yjy/diffsynth-study/DiffSynth-Studio/data/OmniData/midjourney-v6-520k-raw-diffsynth/metadata.csv \ +modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "flux2/FLUX.2-klein-base-4B/*" --local_dir ./data/diffsynth_example_dataset + +accelerate launch examples/flux2/model_training/train.py \ + --dataset_base_path data/diffsynth_example_dataset/flux2/FLUX.2-klein-base-4B \ + --dataset_metadata_path data/diffsynth_example_dataset/flux2/FLUX.2-klein-base-4B/metadata.csv \ --height 512 \ --width 512 \ - --dataset_repeat 1 \ + --dataset_repeat 100 \ --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \ --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \ --learning_rate 1e-5 \ --dmd2_fake_score_learning_rate 1e-5 \ --dmd2_discriminator_learning_rate 1e-5 \ - --num_epochs 1 \ + --num_epochs 10 \ --save_steps 1000 \ --remove_prefix_in_ckpt "pipe.dit." \ --output_path "./models/train/FLUX.2-klein-base-4B_dmd2-4steps" \ From 09f6843cbe1da3cc7809a34ca49ed8df7554a142 Mon Sep 17 00:00:00 2001 From: yjy415 <2471352175@qq.com> Date: Thu, 18 Jun 2026 16:56:34 +0800 Subject: [PATCH 3/6] fix: examples --- .../model_training/validate_full/FLUX.2-klein-base-4B-DMD2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B-DMD2.py b/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B-DMD2.py index ead3cc881..a60844230 100644 --- a/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B-DMD2.py +++ b/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B-DMD2.py @@ -12,7 +12,7 @@ ], tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"), ) -state_dict = load_state_dict("./models/train/FLUX.2-klein-base-4B_dmd2-4steps/step-200.safetensors", torch_dtype=torch.bfloat16) +state_dict = load_state_dict("./models/train/FLUX.2-klein-base-4B_dmd2-4steps/step-2000.safetensors", torch_dtype=torch.bfloat16) pipe.dit.load_state_dict(state_dict) prompt = "a dog" From 9fc4bbca38edd1b1b5c8448ca022304d79850ed2 Mon Sep 17 00:00:00 2001 From: yjy415 <2471352175@qq.com> Date: Thu, 18 Jun 2026 17:10:13 +0800 Subject: [PATCH 4/6] fix: examples --- .../model_training/special/dmd2/FLUX.2-klein-base-4B-DMD2.sh | 4 ---- .../model_training/validate_full/FLUX.2-klein-base-4B-DMD2.py | 2 +- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/examples/flux2/model_training/special/dmd2/FLUX.2-klein-base-4B-DMD2.sh b/examples/flux2/model_training/special/dmd2/FLUX.2-klein-base-4B-DMD2.sh index 49e2b190e..097abd6d9 100644 --- a/examples/flux2/model_training/special/dmd2/FLUX.2-klein-base-4B-DMD2.sh +++ b/examples/flux2/model_training/special/dmd2/FLUX.2-klein-base-4B-DMD2.sh @@ -19,12 +19,8 @@ accelerate launch examples/flux2/model_training/train.py \ --task dmd2 \ --dmd2_student_sample_steps 4 \ --dmd2_student_sample_type sde \ - --dmd2_student_schedule uniform \ --dmd2_student_update_freq 5 \ - --dmd2_matching_t_sampling uniform \ --dmd2_gan_loss_weight 0.03 \ - --dmd2_matching_t_min 0.001 \ - --dmd2_matching_t_max 0.999 \ --dmd2_feature_indices 12 \ --embedded_guidance 4 \ --dmd2_teacher_cfg_scale 4 \ diff --git a/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B-DMD2.py b/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B-DMD2.py index a60844230..f269d4e2e 100644 --- a/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B-DMD2.py +++ b/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B-DMD2.py @@ -4,7 +4,7 @@ pipe = Flux2ImagePipeline.from_pretrained( torch_dtype=torch.bfloat16, - device="cuda:2", + device="cuda", model_configs=[ ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"), ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors"), From 3a4597870fa9eaa241ebc83e534b2a9b6073bb6e Mon Sep 17 00:00:00 2001 From: yjy415 <2471352175@qq.com> Date: Thu, 18 Jun 2026 18:47:07 +0800 Subject: [PATCH 5/6] fix: examples --- .../special/dmd2/FLUX.2-klein-base-4B-DMD2.sh | 30 ++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/examples/flux2/model_training/special/dmd2/FLUX.2-klein-base-4B-DMD2.sh b/examples/flux2/model_training/special/dmd2/FLUX.2-klein-base-4B-DMD2.sh index 097abd6d9..9ec988172 100644 --- a/examples/flux2/model_training/special/dmd2/FLUX.2-klein-base-4B-DMD2.sh +++ b/examples/flux2/model_training/special/dmd2/FLUX.2-klein-base-4B-DMD2.sh @@ -24,4 +24,32 @@ accelerate launch examples/flux2/model_training/train.py \ --dmd2_feature_indices 12 \ --embedded_guidance 4 \ --dmd2_teacher_cfg_scale 4 \ - --use_gradient_checkpointing \ No newline at end of file + --use_gradient_checkpointing + + +# 使用更完整的训练数据 +# accelerate launch --num_processes 1 --num_machines 1 --mixed_precision bf16 --dynamo_backend no examples/flux2/model_training/train.py \ +# --dataset_base_path /mnt/nas1/yejinyan.yjy/diffsynth-study/DiffSynth-Studio/data/OmniData/midjourney-v6-520k-raw-diffsynth \ +# --dataset_metadata_path /mnt/nas1/yejinyan.yjy/diffsynth-study/DiffSynth-Studio/data/OmniData/midjourney-v6-520k-raw-diffsynth/metadata.csv \ + # --height 512 \ + # --width 512 \ + # --dataset_repeat 100 \ + # --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \ + # --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \ + # --learning_rate 1e-5 \ + # --dmd2_fake_score_learning_rate 1e-5 \ + # --dmd2_discriminator_learning_rate 1e-5 \ + # --num_epochs 10 \ + # --save_steps 1000 \ + # --remove_prefix_in_ckpt "pipe.dit." \ + # --output_path "./models/train/FLUX.2-klein-base-4B_dmd2-4steps" \ + # --trainable_models "dit" \ + # --task dmd2 \ + # --dmd2_student_sample_steps 4 \ + # --dmd2_student_sample_type sde \ + # --dmd2_student_update_freq 5 \ + # --dmd2_gan_loss_weight 0.03 \ + # --dmd2_feature_indices 12 \ + # --embedded_guidance 4 \ + # --dmd2_teacher_cfg_scale 4 \ + # --use_gradient_checkpointing From c509ae0b7e77043a5f35f0eeacfa624366cd3c52 Mon Sep 17 00:00:00 2001 From: yjy415 <2471352175@qq.com> Date: Fri, 26 Jun 2026 15:27:26 +0800 Subject: [PATCH 6/6] fix: dmd2 for flux2-klein-base-4B --- diffsynth/diffusion/dmd2.py | 431 +++++---------- diffsynth/diffusion/parsers.py | 21 + diffsynth/pipelines/flux2_image.py | 126 +---- .../special/dmd2/FLUX.2-klein-base-4B-DMD2.sh | 34 +- .../model_training/special/dmd2/train.py | 514 ++++++++++++++++++ examples/flux2/model_training/train.py | 30 +- .../FLUX.2-klein-base-4B-DMD2.py | 4 +- 7 files changed, 691 insertions(+), 469 deletions(-) create mode 100644 examples/flux2/model_training/special/dmd2/train.py diff --git a/diffsynth/diffusion/dmd2.py b/diffsynth/diffusion/dmd2.py index f5d9e2cd5..091d1ef17 100644 --- a/diffsynth/diffusion/dmd2.py +++ b/diffsynth/diffusion/dmd2.py @@ -1,4 +1,3 @@ -import copy from dataclasses import dataclass from typing import Optional import torch @@ -6,18 +5,7 @@ from accelerate import Accelerator from tqdm import tqdm from .runner import get_optimizer_class, initialize_deepspeed_gradient_checkpointing - - -def _parse_int_list(value): - if value is None or value == "": - return None - return [int(i) for i in value.split(",") if i != ""] - - -def _parse_float_list(value): - if value is None or value == "": - return None - return [float(i) for i in value.split(",") if i != ""] +from diffsynth.core import OffloadTrainingManager @dataclass @@ -25,7 +13,7 @@ class DMD2Config: student_update_freq: int = 5 student_sample_steps: int = 4 student_sample_type: str = "sde" - student_schedule: str = "uniform" + student_schedule: str = "uniform" student_t_list: Optional[list[float]] = None matching_t_min: float = 0.001 matching_t_max: float = 0.999 @@ -35,190 +23,73 @@ class DMD2Config: gan_loss_weight: float = 0.03 gan_r1_reg_weight: float = 0.0 gan_r1_reg_alpha: float = 0.1 - gan_logit_reg_weight: float = 0.0 fake_score_learning_rate: Optional[float] = None discriminator_learning_rate: Optional[float] = None feature_indices: Optional[list[int]] = None - discriminator_hidden_dim: Optional[int] = None - discriminator_num_blocks: Optional[int] = None - teacher_cfg_scale: Optional[float] = None + teacher_cfg_scale: float = 1.0 student_grad_clip_norm: Optional[float] = 10.0 - @classmethod - def from_args(cls, args): - return cls( - student_update_freq=args.dmd2_student_update_freq, - student_sample_steps=args.dmd2_student_sample_steps, - student_sample_type=args.dmd2_student_sample_type, - student_schedule=args.dmd2_student_schedule, - student_t_list=_parse_float_list(args.dmd2_student_t_list), - matching_t_min=args.dmd2_matching_t_min, - matching_t_max=args.dmd2_matching_t_max, - matching_t_sampling=args.dmd2_matching_t_sampling, - matching_t_mean=args.dmd2_matching_t_mean, - matching_t_std=args.dmd2_matching_t_std, - gan_loss_weight=args.dmd2_gan_loss_weight, - gan_r1_reg_weight=args.dmd2_gan_r1_reg_weight, - gan_r1_reg_alpha=args.dmd2_gan_r1_reg_alpha, - gan_logit_reg_weight=args.dmd2_gan_logit_reg_weight, - fake_score_learning_rate=args.dmd2_fake_score_learning_rate, - discriminator_learning_rate=args.dmd2_discriminator_learning_rate, - feature_indices=_parse_int_list(args.dmd2_feature_indices), - discriminator_hidden_dim=args.dmd2_discriminator_hidden_dim, - discriminator_num_blocks=args.dmd2_discriminator_num_blocks, - teacher_cfg_scale=args.dmd2_teacher_cfg_scale, - student_grad_clip_norm=args.dmd2_student_grad_clip_norm, - ) +def _get_dmd2_pipe_model(module, model_name): + if model_name is None: + return None + return getattr(module.pipe, model_name) -def add_dmd2_config(parser): - parser.add_argument("--dmd2_student_update_freq", type=int, default=5, help="Update student once every N DMD2 iterations.") - parser.add_argument("--dmd2_student_sample_steps", type=int, default=4, help="Number of distilled student sampling steps.") - parser.add_argument("--dmd2_student_sample_type", type=str, default="sde", choices=["sde", "ode"], help="Student sampling type used by the DMD2 objective.") - parser.add_argument("--dmd2_student_schedule", type=str, default="uniform", choices=["uniform"], help="Student sigma schedule.") - parser.add_argument("--dmd2_student_t_list", type=str, default=None, help="Optional student sigma schedule, including the final 0.") - parser.add_argument("--dmd2_matching_t_min", type=float, default=0.001, help="Minimum matching sigma sampled for DMD2.") - parser.add_argument("--dmd2_matching_t_max", type=float, default=0.999, help="Maximum matching sigma sampled for DMD2.") - parser.add_argument("--dmd2_matching_t_sampling", type=str, default="uniform", choices=["uniform", "logitnormal"], help="Sample matching sigma.") - parser.add_argument("--dmd2_matching_t_mean", type=float, default=0.0, help="Mean for logitnormal matching timestep sampling.") - parser.add_argument("--dmd2_matching_t_std", type=float, default=1.0, help="Std for logitnormal matching timestep sampling.") - parser.add_argument("--dmd2_gan_loss_weight", type=float, default=0.03, help="Generator GAN loss weight.") - parser.add_argument("--dmd2_gan_r1_reg_weight", type=float, default=0.0, help="Approximate R1 regularization weight for the discriminator.") - parser.add_argument("--dmd2_gan_r1_reg_alpha", type=float, default=0.1, help="Noise scale for approximate R1 regularization.") - parser.add_argument("--dmd2_gan_logit_reg_weight", type=float, default=0.0, help="L2 regularization weight on discriminator logits.") - parser.add_argument("--dmd2_fake_score_learning_rate", type=float, default=None, help="Learning rate for the fake score model.") - parser.add_argument("--dmd2_discriminator_learning_rate", type=float, default=None, help="Learning rate for the discriminator.") - parser.add_argument("--dmd2_feature_indices", type=str, default=None, help="DiT block indices used by the discriminator.") - parser.add_argument("--dmd2_discriminator_hidden_dim", type=int, default=None, help="Hidden dimension of DiT features.") - parser.add_argument("--dmd2_discriminator_num_blocks", type=int, default=None, help="Total Flux block count.") - parser.add_argument("--dmd2_teacher_cfg_scale", type=float, default=None, help="CFG scale applied to the teacher x0 in DMD2.") - parser.add_argument("--dmd2_student_grad_clip_norm", type=float, default=10.0, help="Clip student gradients to this norm.") - return parser - - -def _get_optimal_groups(num_channels): - if num_channels <= 32: - groups = max(1, num_channels // 4) - else: - groups = 32 - while groups > 1 and num_channels % groups != 0: - groups -= 1 - assert num_channels % groups == 0, f"{num_channels} not divisible by {groups}" - return groups - - -class FluxDMD2Discriminator(torch.nn.Module): - def __init__(self, feature_indices=None, num_blocks=None, inner_dim=None): - super().__init__() - if num_blocks is None: - raise ValueError("`num_blocks` must be provided.") - if inner_dim is None: - raise ValueError("`inner_dim` must be provided.") - if feature_indices is None: - feature_indices = [int(num_blocks // 2)] - self.feature_indices = sorted({int(i) for i in feature_indices if 0 <= int(i) < num_blocks}) - if len(self.feature_indices) == 0: - raise ValueError("DMD2 discriminator requires at least one valid feature index.") - self.num_features = len(self.feature_indices) - self.inner_dim = inner_dim - - hidden_channels = inner_dim // 2 - self.heads = torch.nn.ModuleList([ - torch.nn.Sequential( - torch.nn.Conv2d(inner_dim, hidden_channels, kernel_size=4, stride=2, padding=1), - torch.nn.GroupNorm(_get_optimal_groups(hidden_channels), hidden_channels), - torch.nn.LeakyReLU(0.2), - torch.nn.Conv2d(hidden_channels, 1, kernel_size=1, stride=1, padding=0), - torch.nn.AdaptiveAvgPool2d((1, 1)), - torch.nn.Flatten(), - ) - for _ in self.feature_indices - ]) - - def forward(self, feats): - if not isinstance(feats, list) or len(feats) != self.num_features: - raise ValueError(f"Expected list of {self.num_features} feature tensors, got {type(feats)} with length {len(feats) if isinstance(feats, list) else 'N/A'}.") - logits = [] - for head, feat in zip(self.heads, feats): - param = next(head.parameters()) - feat = feat.to(device=param.device, dtype=param.dtype) - logits.append(head(feat)) - return torch.cat(logits, dim=1) - - -def _infer_dit_num_blocks(dit, default=40): - if hasattr(dit, "blocks") and hasattr(dit, "single_blocks"): - return len(dit.blocks) + len(dit.single_blocks) - if hasattr(dit, "transformer_blocks") and hasattr(dit, "single_transformer_blocks"): - return len(dit.transformer_blocks) + len(dit.single_transformer_blocks) - return default - - -def _infer_dit_hidden_dim(dit, default=3072): - return int(getattr(dit, "inner_dim", default)) - - -def setup_dmd2_training(module, pipe, config: DMD2Config): - module.dmd2_config = config - module.teacher_dit = copy.deepcopy(pipe.dit).eval().requires_grad_(False) - module.fake_score = copy.deepcopy(pipe.dit).train().requires_grad_(True) - module.discriminator = None - if config.gan_loss_weight > 0: - discriminator_num_blocks = config.discriminator_num_blocks or _infer_dit_num_blocks(pipe.dit) - discriminator_hidden_dim = config.discriminator_hidden_dim or _infer_dit_hidden_dim(pipe.dit) - module.discriminator = FluxDMD2Discriminator( - feature_indices=config.feature_indices, - num_blocks=discriminator_num_blocks, - inner_dim=discriminator_hidden_dim, - ) - module.dmd2_loss = DMD2Loss(config) - module._dmd2_student_param_names = {name for name, param in pipe.dit.named_parameters() if param.requires_grad} - module._dmd2_fake_score_param_names = {name for name, _ in module.fake_score.named_parameters()} - def input_processor(inputs_shared): - return prepare_dmd2_pipeline_inputs(inputs_shared, config) +def get_dmd2_student_model(module): + return _get_dmd2_pipe_model(module, module.dmd2_student_model_name) - def loss_fn(pipe, inputs_shared, inputs_posi, inputs_nega, iteration=None, **kwargs): - return module.dmd2_loss( - module, - (inputs_shared, inputs_posi, inputs_nega), - 0 if iteration is None else iteration, - ) - def state_dict_exporter(state_dict, remove_prefix=None): - return export_dmd2_trainable_state_dict(module, state_dict, remove_prefix=remove_prefix) +def get_dmd2_teacher_model(module): + return _get_dmd2_pipe_model(module, module.dmd2_teacher_model_name) - module.task_to_input_processor["dmd2"] = input_processor - module.task_to_loss["dmd2"] = loss_fn - module.task_to_state_dict_exporter["dmd2"] = state_dict_exporter +def get_dmd2_fake_score_model(module): + return _get_dmd2_pipe_model(module, module.dmd2_fake_score_model_name) -def prepare_dmd2_pipeline_inputs(inputs_shared, config: DMD2Config): - if _cfg_enabled(config.teacher_cfg_scale): - inputs_shared["cfg_scale"] = config.teacher_cfg_scale - return inputs_shared + +def get_dmd2_discriminator(module): + return _get_dmd2_pipe_model(module, module.dmd2_discriminator_model_name) + + +def _dmd2_model_state_names(module, model_name, param_names=None): + if model_name is None: + return set() + model = getattr(module.pipe, model_name, None) + if model is None: + return set() + names = model.state_dict().keys() if param_names is None else param_names + return {f"pipe.{model_name}.{name}" for name in names} def export_dmd2_trainable_state_dict(module, state_dict, remove_prefix=None): - student_names = {"pipe.dit." + name for name in module._dmd2_student_param_names} - state_dict = {name: param for name, param in state_dict.items() if name in student_names} + student_names = _dmd2_model_state_names(module, module.dmd2_student_model_name, module._dmd2_student_param_names) + state_names = set(student_names) + if remove_prefix is None: + state_names.update(_dmd2_model_state_names(module, module.dmd2_teacher_model_name)) + state_names.update(_dmd2_model_state_names(module, module.dmd2_fake_score_model_name)) + state_names.update(_dmd2_model_state_names(module, module.dmd2_discriminator_model_name)) + state_dict = {name: param for name, param in state_dict.items() if name in state_names} if remove_prefix is not None: state_dict = {name[len(remove_prefix):] if name.startswith(remove_prefix) else name: param for name, param in state_dict.items()} return state_dict def set_dmd2_train_phase(module, student_phase: bool): - module.pipe.dit.train(student_phase) - for name, param in module.pipe.dit.named_parameters(): + student_model = get_dmd2_student_model(module) + student_model.train(student_phase) + for name, param in student_model.named_parameters(): param.requires_grad = student_phase and name in module._dmd2_student_param_names - module.fake_score.train(not student_phase) - for name, param in module.fake_score.named_parameters(): + fake_score_model = get_dmd2_fake_score_model(module) + fake_score_model.train(not student_phase) + for name, param in fake_score_model.named_parameters(): param.requires_grad = (not student_phase) and name in module._dmd2_fake_score_param_names - if module.discriminator is not None: - module.discriminator.train(not student_phase) - module.discriminator.requires_grad_(not student_phase) + discriminator = get_dmd2_discriminator(module) + if discriminator is not None: + discriminator.train(not student_phase) + discriminator.requires_grad_(not student_phase) def _expand_like(value, target, dtype=None): @@ -232,7 +103,7 @@ def _expand_like(value, target, dtype=None): return value -def _flow_to_x0(latents, flow, sigma): +def flow_to_x0(latents, flow, sigma): original_dtype = latents.dtype latents = latents.to(torch.float64) flow = flow.to(torch.float64) @@ -290,11 +161,6 @@ def _variational_score_distillation_loss(gen_data, teacher_x0, fake_score_x0): return loss -def _mean_abs_by_sample(value): - dims = tuple(range(1, value.ndim)) - return value.detach().float().abs().mean(dim=dims).mean() - - def _gan_loss_generator(fake_logits): assert fake_logits.ndim == 2, f"fake_logits has shape {fake_logits.shape}" gan_loss = F.softplus(-fake_logits).mean() @@ -308,66 +174,62 @@ def _gan_loss_discriminator(real_logits, fake_logits): return gan_loss -def _cfg_enabled(cfg_scale): - return cfg_scale is not None and abs(float(cfg_scale) - 1.0) > 1e-12 - - class DMD2Loss: def __init__(self, config: DMD2Config): self.config = config - self._last_teacher_cfg_delta = None def _model_forward_x0( self, module, - dit, + model, + model_fn, latents, timestep, sigma, inputs_shared, inputs_posi, ): - pipe = module.pipe - models = {name: getattr(pipe, name) for name in pipe.in_iteration_models} - models["dit"] = dit shared = dict(inputs_shared) posi = dict(inputs_posi) shared["latents"] = latents - flow = pipe.model_fn( - **models, - **shared, - **posi, + flow = model_fn( + module.pipe, + model, timestep=timestep, progress_id=0, num_inference_steps=1, + inputs_shared=shared, + inputs_posi=posi, ) - return _flow_to_x0(latents, flow, sigma) + return flow_to_x0(latents, flow, sigma) def _model_forward_features( self, module, - dit, + model, + model_fn, latents, timestep, inputs_shared, inputs_posi, ): - if module.discriminator is None: + discriminator = get_dmd2_discriminator(module) + if discriminator is None: raise ValueError("DMD2 feature extraction requires a discriminator.") - pipe = module.pipe - models = {name: getattr(pipe, name) for name in pipe.in_iteration_models} - models["dit"] = dit + if self.config.feature_indices is None: + raise ValueError("DMD2 feature extraction requires `dmd2_feature_indices`.") shared = dict(inputs_shared) posi = dict(inputs_posi) shared["latents"] = latents - return pipe.model_fn( - **models, - **shared, - **posi, + return model_fn( + module.pipe, + model, timestep=timestep, progress_id=0, num_inference_steps=1, - feature_indices=set(module.discriminator.feature_indices), + inputs_shared=shared, + inputs_posi=posi, + feature_indices=set(self.config.feature_indices), return_features=True, ) @@ -383,46 +245,29 @@ def _teacher_forward_x0( ): teacher_x0_pos = self._model_forward_x0( module, - module.teacher_dit, + get_dmd2_teacher_model(module), + module.dmd2_model_fn_teacher, latents, timestep, sigma, inputs_shared, inputs_posi, ) - if not _cfg_enabled(self.config.teacher_cfg_scale): - self._last_teacher_cfg_delta = torch.zeros((), device=latents.device, dtype=torch.float32) + if self.config.teacher_cfg_scale <= 1.0: return teacher_x0_pos + x0_neg = self._model_forward_x0( module, - module.teacher_dit, + get_dmd2_teacher_model(module), + module.dmd2_model_fn_teacher, latents, timestep, sigma, inputs_shared, inputs_nega, ) - x0 = x0_neg + float(self.config.teacher_cfg_scale) * (teacher_x0_pos - x0_neg) - self._last_teacher_cfg_delta = _mean_abs_by_sample(x0 - teacher_x0_pos) - return x0 - - def _teacher_forward_features( - self, - module, - latents, - timestep, - inputs_shared, - inputs_posi, - ): - return self._model_forward_features( - module, - module.teacher_dit, - latents, - timestep, - inputs_shared, - inputs_posi, - ) + return x0_neg + float(self.config.teacher_cfg_scale) * (teacher_x0_pos - x0_neg) def _sample_matching_timestep(self, pipe, device, dtype, batch_size=1, inputs_shared=None, real_data=None): if self.config.matching_t_min > self.config.matching_t_max: @@ -471,7 +316,8 @@ def _generate_student_data(self, module, real_data, inputs_shared, inputs_posi): input_student = _forward_process(real_data, eps_student, sigma) gen_data = self._model_forward_x0( module, - module.pipe.dit, + get_dmd2_student_model(module), + module.dmd2_model_fn_student, input_student, timestep, sigma, @@ -485,7 +331,8 @@ def _compute_real_feat(self, module, real_data, timestep, sigma, eps, inputs_sha perturbed_real = _forward_process(real_data, real_eps, real_sigma) real_feat = self._model_forward_features( module, - module.teacher_dit, + get_dmd2_teacher_model(module), + module.dmd2_model_fn_teacher, perturbed_real, real_timestep, inputs_shared, @@ -509,7 +356,8 @@ def _student_update_step(self, module, real_data, inputs_shared, inputs_posi, in with torch.no_grad(): fake_score_x0 = self._model_forward_x0( module, - module.fake_score, + get_dmd2_fake_score_model(module), + module.dmd2_model_fn_fake_score, perturbed_data, timestep, sigma, @@ -517,50 +365,30 @@ def _student_update_step(self, module, real_data, inputs_shared, inputs_posi, in inputs_posi, ) + with torch.no_grad(): + teacher_x0 = self._teacher_forward_x0( + module, perturbed_data, timestep, sigma, inputs_shared, inputs_posi, inputs_nega + ) if self.config.gan_loss_weight > 0: - with torch.no_grad(): - teacher_x0 = self._teacher_forward_x0( - module, perturbed_data, timestep, sigma, inputs_shared, inputs_posi, inputs_nega - ) - fake_feat = self._teacher_forward_features( + fake_feat = self._model_forward_features( module, + get_dmd2_teacher_model(module), + module.dmd2_model_fn_teacher, perturbed_data, timestep, inputs_shared, inputs_posi, ) - fake_logits_gen = module.discriminator(fake_feat) + fake_logits_gen = get_dmd2_discriminator(module)(module.pipe, fake_feat) gan_loss_gen = _gan_loss_generator(fake_logits_gen) else: - with torch.no_grad(): - teacher_x0 = self._teacher_forward_x0( - module, perturbed_data, timestep, sigma, inputs_shared, inputs_posi, inputs_nega - ) gan_loss_gen = torch.zeros((), device=real_data.device, dtype=torch.float32) vsd_loss = _variational_score_distillation_loss(gen_data, teacher_x0.detach(), fake_score_x0) gan_loss_weighted = self.config.gan_loss_weight * gan_loss_gen loss = vsd_loss + gan_loss_weighted - with torch.no_grad(): - teacher_delta = _mean_abs_by_sample(gen_data - teacher_x0) - fake_delta = _mean_abs_by_sample(gen_data - fake_score_x0) - vsd_delta = _mean_abs_by_sample(fake_score_x0 - teacher_x0) - effective_gan_weight = gan_loss_weighted.detach() / gan_loss_gen.detach().clamp_min(1e-12) - - return { - "total_loss": loss, - "vsd_loss": vsd_loss.detach(), - "gan_loss_gen": gan_loss_gen.detach(), - "gan_loss_gen_weighted": gan_loss_weighted.detach(), - "gan_loss_effective_weight": effective_gan_weight, - "dmd2_teacher_delta": teacher_delta, - "dmd2_teacher_cfg_delta": self._last_teacher_cfg_delta.detach(), - "dmd2_fake_score_delta": fake_delta, - "dmd2_vsd_delta": vsd_delta, - "dmd2_sigma_mean": sigma.detach().float().mean(), - "student_input_mean": input_student.detach().float().mean(), - } + return {"total_loss": loss} def _fake_score_discriminator_update_step(self, module, real_data, inputs_shared, inputs_posi, inputs_nega): with torch.no_grad(): @@ -578,7 +406,8 @@ def _fake_score_discriminator_update_step(self, module, real_data, inputs_shared fake_score_x0 = self._model_forward_x0( module, - module.fake_score, + get_dmd2_fake_score_model(module), + module.dmd2_model_fn_fake_score, x_t_sg, timestep, sigma, @@ -586,19 +415,16 @@ def _fake_score_discriminator_update_step(self, module, real_data, inputs_shared inputs_posi, ) fake_score_loss = F.mse_loss(fake_score_x0.float(), gen_data.float(), reduction="mean") - with torch.no_grad(): - fake_score_delta = _mean_abs_by_sample(fake_score_x0 - gen_data) gan_loss_disc = torch.zeros_like(fake_score_loss) gan_loss_ar1 = torch.zeros_like(fake_score_loss) gan_loss_logit_reg = torch.zeros_like(fake_score_loss) - real_logit_mean = torch.zeros_like(fake_score_loss) - fake_logit_mean = torch.zeros_like(fake_score_loss) if self.config.gan_loss_weight > 0: with torch.no_grad(): fake_feat = self._model_forward_features( module, - module.teacher_dit, + get_dmd2_teacher_model(module), + module.dmd2_model_fn_teacher, x_t_sg, timestep, inputs_shared, @@ -607,45 +433,31 @@ def _fake_score_discriminator_update_step(self, module, real_data, inputs_shared real_feat, real_timestep, real_sigma = self._compute_real_feat( module, real_data, timestep, sigma, eps, inputs_shared, inputs_posi ) - real_logits = module.discriminator(real_feat) - fake_logits = module.discriminator(fake_feat) - real_logit_mean = real_logits.detach().float().mean() - fake_logit_mean = fake_logits.detach().float().mean() + discriminator = get_dmd2_discriminator(module) + real_logits = discriminator(module.pipe, real_feat) + fake_logits = discriminator(module.pipe, fake_feat) gan_loss_disc = _gan_loss_discriminator(real_logits, fake_logits) - if self.config.gan_logit_reg_weight > 0: - gan_loss_logit_reg = 0.5 * (real_logits.float().square().mean() + fake_logits.float().square().mean()) if self.config.gan_r1_reg_weight > 0: perturbed_real_alpha = real_data + self.config.gan_r1_reg_alpha * torch.randn_like(real_data) with torch.no_grad(): real_feat_alpha = self._model_forward_features( module, - module.teacher_dit, + get_dmd2_teacher_model(module), + module.dmd2_model_fn_teacher, perturbed_real_alpha, real_timestep, inputs_shared, inputs_posi, ) - real_logits_alpha = module.discriminator(real_feat_alpha) + real_logits_alpha = discriminator(module.pipe, real_feat_alpha) gan_loss_ar1 = F.mse_loss(real_logits, real_logits_alpha, reduction="mean") loss = ( fake_score_loss + gan_loss_disc + self.config.gan_r1_reg_weight * gan_loss_ar1 - + self.config.gan_logit_reg_weight * gan_loss_logit_reg ) - return { - "total_loss": loss, - "fake_score_loss": fake_score_loss.detach(), - "fake_score_delta": fake_score_delta, - "gan_loss_disc": gan_loss_disc.detach(), - "gan_loss_ar1": gan_loss_ar1.detach(), - "gan_loss_logit_reg": gan_loss_logit_reg.detach(), - "gan_loss_logit_reg_weighted": (self.config.gan_logit_reg_weight * gan_loss_logit_reg).detach(), - "gan_real_logit": real_logit_mean, - "gan_fake_logit": fake_logit_mean, - "dmd2_sigma_mean": sigma.detach().float().mean(), - } + return {"total_loss": loss} def __call__(self, module, inputs, iteration): inputs_shared, inputs_posi, inputs_nega = inputs @@ -684,6 +496,9 @@ def launch_dmd2_training_task( num_workers: int = 1, save_steps: int = None, num_epochs: int = 1, + enable_model_cpu_offload: bool = False, + enable_optimizer_cpu_offload: bool = False, + cpu_offload_split_threshold: int = None, customized_optimizer: str = None, args=None, **kwargs, @@ -694,15 +509,23 @@ def launch_dmd2_training_task( num_workers = args.dataset_num_workers save_steps = args.save_steps num_epochs = args.num_epochs + enable_model_cpu_offload = args.enable_model_cpu_offload + enable_optimizer_cpu_offload = args.enable_optimizer_cpu_offload + cpu_offload_split_threshold = args.cpu_offload_split_threshold customized_optimizer = args.customized_optimizer optimizer_class = get_optimizer_class(customized_optimizer) - config = model.dmd2_config + config = model.dmd2_loss.config fake_score_lr = config.fake_score_learning_rate or learning_rate discriminator_lr = config.discriminator_learning_rate or learning_rate - student_optimizer = optimizer_class(_trainable_params(model.pipe.dit), lr=learning_rate, weight_decay=weight_decay) - fake_score_optimizer = optimizer_class(model.fake_score.parameters(), lr=fake_score_lr, weight_decay=weight_decay) + student_model = get_dmd2_student_model(model) + fake_score_model = get_dmd2_fake_score_model(model) + discriminator = get_dmd2_discriminator(model) + has_discriminator = discriminator is not None + + student_optimizer = optimizer_class(_trainable_params(student_model), lr=learning_rate, weight_decay=weight_decay) + fake_score_optimizer = optimizer_class(fake_score_model.parameters(), lr=fake_score_lr, weight_decay=weight_decay) student_scheduler = torch.optim.lr_scheduler.ConstantLR(student_optimizer, factor=1.0, total_iters=1) fake_score_scheduler = torch.optim.lr_scheduler.ConstantLR(fake_score_optimizer, factor=1.0, total_iters=1) @@ -713,8 +536,8 @@ def launch_dmd2_training_task( "fake_score_scheduler": fake_score_scheduler, } prepare_items = [model, student_optimizer, fake_score_optimizer, student_scheduler, fake_score_scheduler] - if model.discriminator is not None: - discriminator_optimizer = optimizer_class(model.discriminator.parameters(), lr=discriminator_lr, weight_decay=weight_decay) + if has_discriminator: + discriminator_optimizer = optimizer_class(discriminator.parameters(), lr=discriminator_lr, weight_decay=weight_decay) discriminator_scheduler = torch.optim.lr_scheduler.ConstantLR(discriminator_optimizer, factor=1.0, total_iters=1) optimizers["discriminator"] = discriminator_optimizer optimizers["discriminator_scheduler"] = discriminator_scheduler @@ -722,16 +545,28 @@ def launch_dmd2_training_task( dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers) prepare_items.append(dataloader) - model.to(device=accelerator.device) - prepared = accelerator.prepare(*prepare_items) - model = prepared[0] - prepared_tail = list(prepared[1:]) + if enable_model_cpu_offload: + prepared = accelerator.prepare(*prepare_items[1:]) + model.pipe.device = accelerator.device + offload_manager = OffloadTrainingManager( + model, + accelerator.device, + enable_optimizer_cpu_offload, + cpu_offload_split_threshold, + ) + prepared_tail = list(prepared) + else: + model.to(device=accelerator.device) + prepared = accelerator.prepare(*prepare_items) + model = prepared[0] + prepared_tail = list(prepared[1:]) + optimizers["student"] = prepared_tail.pop(0) optimizers["fake_score"] = prepared_tail.pop(0) optimizers["student_scheduler"] = prepared_tail.pop(0) optimizers["fake_score_scheduler"] = prepared_tail.pop(0) - if model.discriminator is not None: + if has_discriminator: optimizers["discriminator"] = prepared_tail.pop(0) optimizers["discriminator_scheduler"] = prepared_tail.pop(0) dataloader = prepared_tail.pop(0) @@ -748,15 +583,17 @@ def launch_dmd2_training_task( loss = loss_map["total_loss"] current_optimizers, current_schedulers = _dmd2_current_optimizers(config, optimizers, iteration) accelerator.backward(loss) + if enable_model_cpu_offload: + offload_manager.after_backward() if iteration % config.student_update_freq == 0 and config.student_grad_clip_norm is not None and config.student_grad_clip_norm > 0: - accelerator.clip_grad_norm_(_trainable_params(model.pipe.dit), config.student_grad_clip_norm) + accelerator.clip_grad_norm_(_trainable_params(get_dmd2_student_model(accelerator.unwrap_model(model))), config.student_grad_clip_norm) for optimizer in current_optimizers: optimizer.step() for scheduler in current_schedulers: scheduler.step() for optimizer in current_optimizers: optimizer.zero_grad(set_to_none=True) - model_logger.on_step_end(accelerator, model, save_steps, loss=loss, metrics=loss_map) + model_logger.on_step_end(accelerator, model, save_steps, loss=loss) iteration += 1 if save_steps is None: model_logger.on_epoch_end(accelerator, model, epoch_id) diff --git a/diffsynth/diffusion/parsers.py b/diffsynth/diffusion/parsers.py index c2c0b6ef7..cc190516b 100644 --- a/diffsynth/diffusion/parsers.py +++ b/diffsynth/diffusion/parsers.py @@ -81,6 +81,27 @@ def add_logger_config(parser: argparse.ArgumentParser): parser.add_argument("--wandb_project", type=str, default="DiffSynth-Studio", help="Wandb project name.") return parser +def add_dmd2_config(parser: argparse.ArgumentParser): + parser.add_argument("--dmd2_student_update_freq", type=int, default=5, help="Update student once every N DMD2 iterations.") + parser.add_argument("--dmd2_student_sample_steps", type=int, default=4, help="Number of distilled student sampling steps.") + parser.add_argument("--dmd2_student_sample_type", type=str, default="sde", choices=["sde", "ode"], help="Student sampling type.") + parser.add_argument("--dmd2_student_schedule", type=str, default="uniform", choices=["uniform"], help="Student sigma schedule.") + parser.add_argument("--dmd2_student_t_list", type=str, default=None, help="Optional student sigma schedule, including the final 0.") + parser.add_argument("--dmd2_matching_t_min", type=float, default=0.001, help="Minimum matching sigma sampled for DMD2.") + parser.add_argument("--dmd2_matching_t_max", type=float, default=0.999, help="Maximum matching sigma sampled for DMD2.") + parser.add_argument("--dmd2_matching_t_sampling", type=str, default="uniform", choices=["uniform", "logitnormal"], help="Sample matching sigma.") + parser.add_argument("--dmd2_matching_t_mean", type=float, default=0.0, help="Mean for logitnormal matching timestep sampling.") + parser.add_argument("--dmd2_matching_t_std", type=float, default=1.0, help="Std for logitnormal matching timestep sampling.") + parser.add_argument("--dmd2_gan_loss_weight", type=float, default=0.03, help="Generator GAN loss weight.") + parser.add_argument("--dmd2_gan_r1_reg_weight", type=float, default=0.0, help="Approximate R1 regularization weight for the discriminator.") + parser.add_argument("--dmd2_gan_r1_reg_alpha", type=float, default=0.1, help="Noise scale for approximate R1 regularization.") + parser.add_argument("--dmd2_fake_score_learning_rate", type=float, default=None, help="Learning rate for the fake score model.") + parser.add_argument("--dmd2_discriminator_learning_rate", type=float, default=None, help="Learning rate for the discriminator.") + parser.add_argument("--dmd2_feature_indices", type=str, default=None, help="Model feature indices used by the DMD2 discriminator.") + parser.add_argument("--dmd2_teacher_cfg_scale", type=float, default=1.0, help="CFG scale.") + parser.add_argument("--dmd2_student_grad_clip_norm", type=float, default=10.0, help="Clip student gradients to this norm.") + return parser + def add_general_config(parser: argparse.ArgumentParser): parser = add_dataset_base_config(parser) parser = add_model_config(parser) diff --git a/diffsynth/pipelines/flux2_image.py b/diffsynth/pipelines/flux2_image.py index 26f028955..753f6c086 100644 --- a/diffsynth/pipelines/flux2_image.py +++ b/diffsynth/pipelines/flux2_image.py @@ -616,129 +616,29 @@ def model_fn_flux2( extra_text_embedding=None, use_gradient_checkpointing=False, use_gradient_checkpointing_offload=False, - feature_indices=None, - return_features=False, **kwargs, ): - feature_indices = set() if feature_indices is None else set(feature_indices) image_seq_len = latents.shape[1] if edit_latents is not None: image_seq_len = latents.shape[1] latents = torch.concat([latents, edit_latents], dim=1) image_ids = torch.concat([image_ids, edit_image_ids], dim=1) - if embedded_guidance is None: - embedded_guidance = None - elif isinstance(embedded_guidance, torch.Tensor): - embedded_guidance = embedded_guidance.to(device=latents.device, dtype=latents.dtype).flatten() - if embedded_guidance.numel() == 1: - embedded_guidance = embedded_guidance.expand(latents.shape[0]) - elif embedded_guidance.numel() != latents.shape[0]: - raise ValueError("`embedded_guidance` must be a scalar or match the latent batch size.") - else: - embedded_guidance = torch.full((latents.shape[0],), float(embedded_guidance), device=latents.device, dtype=latents.dtype) + embedded_guidance = torch.tensor([embedded_guidance], device=latents.device) if extra_text_embedding is not None: extra_text_ids = torch.zeros((1, extra_text_embedding.shape[1], 4), dtype=text_ids.dtype, device=text_ids.device) extra_text_ids[:, :, -1] = torch.arange(prompt_embeds.shape[1], prompt_embeds.shape[1] + extra_text_embedding.shape[1]) prompt_embeds = torch.concat([prompt_embeds, extra_text_embedding], dim=1) text_ids = torch.concat([text_ids, extra_text_ids], dim=1) - if not return_features: - model_output = dit( - hidden_states=latents, - timestep=timestep / 1000, - guidance=embedded_guidance, - encoder_hidden_states=prompt_embeds, - txt_ids=text_ids, - img_ids=image_ids, - kv_cache=kv_cache, - use_gradient_checkpointing=use_gradient_checkpointing, - use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, - ) - model_output = model_output[:, :image_seq_len] - return model_output - - height, width = kwargs.get("height"), kwargs.get("width") - if height is not None and width is not None: - feature_height, feature_width = int(height) // 16, int(width) // 16 - else: - feature_height = int(math.sqrt(image_seq_len)) - feature_width = image_seq_len // feature_height if feature_height > 0 else 0 - if feature_height * feature_width != image_seq_len: - raise ValueError("Flux2 feature extraction requires height/width or square latent tokens.") - - features = [] - - def append_feature(feat): - feat = feat[:, :image_seq_len] - batch_size, _, channels = feat.shape - feat = feat.permute(0, 2, 1).reshape(batch_size, channels, feature_height, feature_width) - features.append(feat) - if len(features) == len(feature_indices): - return features - return None - - num_txt_tokens = prompt_embeds.shape[1] - timestep = timestep.to(latents.dtype) - guidance = None if embedded_guidance is None else embedded_guidance.to(latents.dtype) * 1000 - temb = dit.time_guidance_embed(timestep, guidance) - - double_stream_mod_img = dit.double_stream_modulation_img(temb) - double_stream_mod_txt = dit.double_stream_modulation_txt(temb) - single_stream_mod = dit.single_stream_modulation(temb)[0] - - hidden_states = dit.x_embedder(latents) - encoder_hidden_states = dit.context_embedder(prompt_embeds) - - if image_ids.ndim == 3: - image_ids = image_ids[0] - if text_ids.ndim == 3: - text_ids = text_ids[0] - - image_rotary_emb = dit.pos_embed(image_ids) - text_rotary_emb = dit.pos_embed(text_ids) - concat_rotary_emb = ( - torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0), - torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0), + model_output = dit( + hidden_states=latents, + timestep=timestep / 1000, + guidance=embedded_guidance, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=image_ids, + kv_cache=kv_cache, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, ) - - for block_id, block in enumerate(dit.transformer_blocks): - encoder_hidden_states, hidden_states = gradient_checkpoint_forward( - block, - use_gradient_checkpointing=use_gradient_checkpointing, - use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - temb_mod_params_img=double_stream_mod_img, - temb_mod_params_txt=double_stream_mod_txt, - image_rotary_emb=concat_rotary_emb, - joint_attention_kwargs=None, - kv_cache=None if kv_cache is None else kv_cache.get(f"double_{block_id}"), - ) - if block_id in feature_indices: - selected_features = append_feature(hidden_states) - if selected_features is not None: - return selected_features - - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) - num_double_blocks = len(dit.transformer_blocks) - - for block_id, block in enumerate(dit.single_transformer_blocks): - hidden_states = gradient_checkpoint_forward( - block, - use_gradient_checkpointing=use_gradient_checkpointing, - use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, - hidden_states=hidden_states, - encoder_hidden_states=None, - temb_mod_params=single_stream_mod, - image_rotary_emb=concat_rotary_emb, - joint_attention_kwargs=None, - kv_cache=None if kv_cache is None else kv_cache.get(f"single_{block_id}"), - ) - feature_id = block_id + num_double_blocks - if feature_id in feature_indices: - selected_features = append_feature(hidden_states[:, num_txt_tokens:num_txt_tokens + image_seq_len]) - if selected_features is not None: - return selected_features - - if len(features) != len(feature_indices): - raise ValueError(f"Only collected {len(features)} feature maps for {len(feature_indices)} requested feature indices.") - return features + model_output = model_output[:, :image_seq_len] + return model_output \ No newline at end of file diff --git a/examples/flux2/model_training/special/dmd2/FLUX.2-klein-base-4B-DMD2.sh b/examples/flux2/model_training/special/dmd2/FLUX.2-klein-base-4B-DMD2.sh index 9ec988172..630764cfc 100644 --- a/examples/flux2/model_training/special/dmd2/FLUX.2-klein-base-4B-DMD2.sh +++ b/examples/flux2/model_training/special/dmd2/FLUX.2-klein-base-4B-DMD2.sh @@ -1,6 +1,6 @@ modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "flux2/FLUX.2-klein-base-4B/*" --local_dir ./data/diffsynth_example_dataset -accelerate launch examples/flux2/model_training/train.py \ +accelerate launch examples/flux2/model_training/special/dmd2/train.py \ --dataset_base_path data/diffsynth_example_dataset/flux2/FLUX.2-klein-base-4B \ --dataset_metadata_path data/diffsynth_example_dataset/flux2/FLUX.2-klein-base-4B/metadata.csv \ --height 512 \ @@ -14,7 +14,7 @@ accelerate launch examples/flux2/model_training/train.py \ --num_epochs 10 \ --save_steps 1000 \ --remove_prefix_in_ckpt "pipe.dit." \ - --output_path "./models/train/FLUX.2-klein-base-4B_dmd2-4steps" \ + --output_path "./models/train/FLUX.2-klein-base-4B_dmd2" \ --trainable_models "dit" \ --task dmd2 \ --dmd2_student_sample_steps 4 \ @@ -24,32 +24,4 @@ accelerate launch examples/flux2/model_training/train.py \ --dmd2_feature_indices 12 \ --embedded_guidance 4 \ --dmd2_teacher_cfg_scale 4 \ - --use_gradient_checkpointing - - -# 使用更完整的训练数据 -# accelerate launch --num_processes 1 --num_machines 1 --mixed_precision bf16 --dynamo_backend no examples/flux2/model_training/train.py \ -# --dataset_base_path /mnt/nas1/yejinyan.yjy/diffsynth-study/DiffSynth-Studio/data/OmniData/midjourney-v6-520k-raw-diffsynth \ -# --dataset_metadata_path /mnt/nas1/yejinyan.yjy/diffsynth-study/DiffSynth-Studio/data/OmniData/midjourney-v6-520k-raw-diffsynth/metadata.csv \ - # --height 512 \ - # --width 512 \ - # --dataset_repeat 100 \ - # --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \ - # --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \ - # --learning_rate 1e-5 \ - # --dmd2_fake_score_learning_rate 1e-5 \ - # --dmd2_discriminator_learning_rate 1e-5 \ - # --num_epochs 10 \ - # --save_steps 1000 \ - # --remove_prefix_in_ckpt "pipe.dit." \ - # --output_path "./models/train/FLUX.2-klein-base-4B_dmd2-4steps" \ - # --trainable_models "dit" \ - # --task dmd2 \ - # --dmd2_student_sample_steps 4 \ - # --dmd2_student_sample_type sde \ - # --dmd2_student_update_freq 5 \ - # --dmd2_gan_loss_weight 0.03 \ - # --dmd2_feature_indices 12 \ - # --embedded_guidance 4 \ - # --dmd2_teacher_cfg_scale 4 \ - # --use_gradient_checkpointing + --use_gradient_checkpointing \ No newline at end of file diff --git a/examples/flux2/model_training/special/dmd2/train.py b/examples/flux2/model_training/special/dmd2/train.py new file mode 100644 index 000000000..c824a57dd --- /dev/null +++ b/examples/flux2/model_training/special/dmd2/train.py @@ -0,0 +1,514 @@ +import torch, os, argparse, accelerate +import argparse +import copy +import math + +from diffsynth.core import UnifiedDataset, gradient_checkpoint_forward +from diffsynth.diffusion import * +from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig + +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +def _parse_int_list(value): + if value is None or value == "": + return None + return [int(i) for i in value.split(",") if i != ""] + + +def _parse_float_list(value): + if value is None or value == "": + return None + return [float(i) for i in value.split(",") if i != ""] + + +def _get_optimal_groups(num_channels): + if num_channels <= 32: + groups = max(1, num_channels // 4) + else: + groups = 32 + while groups > 1 and num_channels % groups != 0: + groups -= 1 + assert num_channels % groups == 0, f"{num_channels} not divisible by {groups}" + return groups + + +class FluxDMD2Discriminator(torch.nn.Module): + def __init__(self, feature_indices=None, num_blocks=40, inner_dim=3072): + super().__init__() + if feature_indices is None: + feature_indices = [int(num_blocks // 2)] + self.feature_indices = sorted({int(i) for i in feature_indices if 0 <= int(i) < num_blocks}) + if len(self.feature_indices) == 0: + raise ValueError("DMD2 discriminator requires at least one valid feature index.") + self.num_features = len(self.feature_indices) + self.inner_dim = inner_dim + + hidden_channels = inner_dim // 2 + self.heads = torch.nn.ModuleList([ + torch.nn.Sequential( + torch.nn.Conv2d(inner_dim, hidden_channels, kernel_size=4, stride=2, padding=1), + torch.nn.GroupNorm(_get_optimal_groups(hidden_channels), hidden_channels), + torch.nn.LeakyReLU(0.2), + torch.nn.Conv2d(hidden_channels, 1, kernel_size=1, stride=1, padding=0), + torch.nn.AdaptiveAvgPool2d((1, 1)), + torch.nn.Flatten(), + ) + for _ in self.feature_indices + ]) + + def forward(self, pipe, feats): + if not isinstance(feats, list) or len(feats) != self.num_features: + raise ValueError(f"Expected list of {self.num_features} feature tensors, got {type(feats)} with length {len(feats) if isinstance(feats, list) else 'N/A'}.") + logits = [] + for head, feat in zip(self.heads, feats): + param = next(head.parameters()) + feat = feat.to(device=param.device, dtype=param.dtype) + logits.append(head(feat)) + return torch.cat(logits, dim=1) + +def _model_fn_flux2_base( + dit, + latents=None, + timestep=None, + embedded_guidance=None, + prompt_embeds=None, + text_ids=None, + image_ids=None, + edit_latents=None, + edit_image_ids=None, + kv_cache=None, + extra_text_embedding=None, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + **kwargs, +): + image_seq_len = latents.shape[1] + + if edit_latents is not None: + image_seq_len = latents.shape[1] + latents = torch.concat([latents, edit_latents], dim=1) + image_ids = torch.concat([image_ids, edit_image_ids], dim=1) + + if embedded_guidance is None: + embedded_guidance = None + elif isinstance(embedded_guidance, torch.Tensor): + embedded_guidance = embedded_guidance.to(device=latents.device, dtype=latents.dtype).flatten() + if embedded_guidance.numel() == 1: + embedded_guidance = embedded_guidance.expand(latents.shape[0]) + elif embedded_guidance.numel() != latents.shape[0]: + raise ValueError("`embedded_guidance` must be a scalar or match the latent batch size.") + else: + embedded_guidance = torch.full( + (latents.shape[0],), + float(embedded_guidance), + device=latents.device, + dtype=latents.dtype, + ) + + if extra_text_embedding is not None: + extra_text_ids = torch.zeros( + (1, extra_text_embedding.shape[1], 4), + dtype=text_ids.dtype, + device=text_ids.device, + ) + extra_text_ids[:, :, -1] = torch.arange( + prompt_embeds.shape[1], + prompt_embeds.shape[1] + extra_text_embedding.shape[1], + ) + prompt_embeds = torch.concat([prompt_embeds, extra_text_embedding], dim=1) + text_ids = torch.concat([text_ids, extra_text_ids], dim=1) + + model_output = dit( + hidden_states=latents, + timestep=timestep / 1000, + guidance=embedded_guidance, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=image_ids, + kv_cache=kv_cache, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + ) + model_output = model_output[:, :image_seq_len] + return model_output + + +def _model_fn_flux2_features( + dit, + latents=None, + timestep=None, + embedded_guidance=None, + prompt_embeds=None, + text_ids=None, + image_ids=None, + edit_latents=None, + edit_image_ids=None, + kv_cache=None, + extra_text_embedding=None, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + feature_indices=None, + **kwargs, +): + feature_indices = set() if feature_indices is None else set(feature_indices) + image_seq_len = latents.shape[1] + if edit_latents is not None: + image_seq_len = latents.shape[1] + latents = torch.concat([latents, edit_latents], dim=1) + image_ids = torch.concat([image_ids, edit_image_ids], dim=1) + if embedded_guidance is None: + embedded_guidance = None + elif isinstance(embedded_guidance, torch.Tensor): + embedded_guidance = embedded_guidance.to(device=latents.device, dtype=latents.dtype).flatten() + if embedded_guidance.numel() == 1: + embedded_guidance = embedded_guidance.expand(latents.shape[0]) + elif embedded_guidance.numel() != latents.shape[0]: + raise ValueError("`embedded_guidance` must be a scalar or match the latent batch size.") + else: + embedded_guidance = torch.full((latents.shape[0],), float(embedded_guidance), device=latents.device, dtype=latents.dtype) + if extra_text_embedding is not None: + extra_text_ids = torch.zeros((1, extra_text_embedding.shape[1], 4), dtype=text_ids.dtype, device=text_ids.device) + extra_text_ids[:, :, -1] = torch.arange(prompt_embeds.shape[1], prompt_embeds.shape[1] + extra_text_embedding.shape[1]) + prompt_embeds = torch.concat([prompt_embeds, extra_text_embedding], dim=1) + text_ids = torch.concat([text_ids, extra_text_ids], dim=1) + + height, width = kwargs.get("height"), kwargs.get("width") + if height is not None and width is not None: + feature_height, feature_width = int(height) // 16, int(width) // 16 + else: + feature_height = int(math.sqrt(image_seq_len)) + feature_width = image_seq_len // feature_height if feature_height > 0 else 0 + if feature_height * feature_width != image_seq_len: + raise ValueError("Flux2 feature extraction requires height/width or square latent tokens.") + + features = [] + + def append_feature(feat): + feat = feat[:, :image_seq_len] + batch_size, _, channels = feat.shape + feat = feat.permute(0, 2, 1).reshape(batch_size, channels, feature_height, feature_width) + features.append(feat) + if len(features) == len(feature_indices): + return features + return None + + num_txt_tokens = prompt_embeds.shape[1] + timestep = timestep.to(latents.dtype) + guidance = None if embedded_guidance is None else embedded_guidance.to(latents.dtype) * 1000 + temb = dit.time_guidance_embed(timestep, guidance) + + double_stream_mod_img = dit.double_stream_modulation_img(temb) + double_stream_mod_txt = dit.double_stream_modulation_txt(temb) + single_stream_mod = dit.single_stream_modulation(temb)[0] + + hidden_states = dit.x_embedder(latents) + encoder_hidden_states = dit.context_embedder(prompt_embeds) + + if image_ids.ndim == 3: + image_ids = image_ids[0] + if text_ids.ndim == 3: + text_ids = text_ids[0] + + image_rotary_emb = dit.pos_embed(image_ids) + text_rotary_emb = dit.pos_embed(text_ids) + concat_rotary_emb = ( + torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0), + torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0), + ) + + for block_id, block in enumerate(dit.transformer_blocks): + encoder_hidden_states, hidden_states = gradient_checkpoint_forward( + block, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb_mod_params_img=double_stream_mod_img, + temb_mod_params_txt=double_stream_mod_txt, + image_rotary_emb=concat_rotary_emb, + joint_attention_kwargs=None, + kv_cache=None if kv_cache is None else kv_cache.get(f"double_{block_id}"), + ) + if block_id in feature_indices: + selected_features = append_feature(hidden_states) + if selected_features is not None: + return selected_features + + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + num_double_blocks = len(dit.transformer_blocks) + + for block_id, block in enumerate(dit.single_transformer_blocks): + hidden_states = gradient_checkpoint_forward( + block, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + hidden_states=hidden_states, + encoder_hidden_states=None, + temb_mod_params=single_stream_mod, + image_rotary_emb=concat_rotary_emb, + joint_attention_kwargs=None, + kv_cache=None if kv_cache is None else kv_cache.get(f"single_{block_id}"), + ) + feature_id = block_id + num_double_blocks + if feature_id in feature_indices: + selected_features = append_feature(hidden_states[:, num_txt_tokens:num_txt_tokens + image_seq_len]) + if selected_features is not None: + return selected_features + + if len(features) != len(feature_indices): + raise ValueError(f"Only collected {len(features)} feature maps for {len(feature_indices)} requested feature indices.") + return features + +def model_fn_flux2_dmd2( + pipe, + dit, + timestep, + progress_id, + num_inference_steps, + inputs_shared, + inputs_posi, + feature_indices=None, + return_features=False, +): + if not return_features: + return _model_fn_flux2_base( + dit=dit, + **inputs_shared, + **inputs_posi, + timestep=timestep, + progress_id=progress_id, + num_inference_steps=num_inference_steps, + ) + + return _model_fn_flux2_features( + dit=dit, + **inputs_shared, + **inputs_posi, + timestep=timestep, + progress_id=progress_id, + num_inference_steps=num_inference_steps, + feature_indices=feature_indices, + ) + + + +class Flux2DMD2TrainingModule(DiffusionTrainingModule): + def __init__( + self, + model_paths=None, model_id_with_origin_paths=None, + tokenizer_path=None, + trainable_models=None, + lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None, + preset_lora_path=None, preset_lora_model=None, + use_gradient_checkpointing=True, + use_gradient_checkpointing_offload=False, + extra_inputs=None, + fp8_models=None, + offload_models=None, + embedded_guidance=1.0, + resume_from_checkpoint=None, + remove_prefix_in_ckpt=None, + dmd2_student_update_freq=5, + dmd2_student_sample_steps=4, + dmd2_student_sample_type="sde", + dmd2_student_schedule="uniform", + dmd2_student_t_list=None, + dmd2_matching_t_min=0.001, + dmd2_matching_t_max=0.999, + dmd2_matching_t_sampling="uniform", + dmd2_matching_t_mean=0.0, + dmd2_matching_t_std=1.0, + dmd2_gan_loss_weight=0.03, + dmd2_gan_r1_reg_weight=0.0, + dmd2_gan_r1_reg_alpha=0.1, + dmd2_fake_score_learning_rate=None, + dmd2_discriminator_learning_rate=None, + dmd2_feature_indices=None, + dmd2_teacher_cfg_scale=None, + dmd2_student_grad_clip_norm=10.0, + device="cpu", + task="dmd2", + ): + super().__init__() + feature_indices = _parse_int_list(dmd2_feature_indices) + config = DMD2Config( + student_update_freq=dmd2_student_update_freq, + student_sample_steps=dmd2_student_sample_steps, + student_sample_type=dmd2_student_sample_type, + student_schedule=dmd2_student_schedule, + student_t_list=_parse_float_list(dmd2_student_t_list), + matching_t_min=dmd2_matching_t_min, + matching_t_max=dmd2_matching_t_max, + matching_t_sampling=dmd2_matching_t_sampling, + matching_t_mean=dmd2_matching_t_mean, + matching_t_std=dmd2_matching_t_std, + gan_loss_weight=dmd2_gan_loss_weight, + gan_r1_reg_weight=dmd2_gan_r1_reg_weight, + gan_r1_reg_alpha=dmd2_gan_r1_reg_alpha, + fake_score_learning_rate=dmd2_fake_score_learning_rate, + discriminator_learning_rate=dmd2_discriminator_learning_rate, + feature_indices=feature_indices, + teacher_cfg_scale=dmd2_teacher_cfg_scale, + student_grad_clip_norm=dmd2_student_grad_clip_norm, + ) + # Load models + model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device) + tokenizer_config = self.parse_path_or_model_id(tokenizer_path, default_value=ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/")) + self.pipe = Flux2ImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config) + self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model, remove_unnecessary_params=True) + + # Training mode + self.switch_pipe_to_training_mode( + self.pipe, trainable_models, + lora_base_model, lora_target_modules, lora_rank, lora_checkpoint, + preset_lora_path, preset_lora_model, + task=task, + ) + + self.pipe.dit_teacher = copy.deepcopy(self.pipe.dit) + self.pipe.dit_fake_score = copy.deepcopy(self.pipe.dit) + self.pipe.dmd2_discriminator = None + if config.gan_loss_weight > 0: + self.pipe.dmd2_discriminator = FluxDMD2Discriminator(feature_indices=config.feature_indices) + self.pipe.dit_teacher.eval().requires_grad_(False) + self.pipe.dit_fake_score.train().requires_grad_(True) + if self.pipe.dmd2_discriminator is not None: + self.pipe.dmd2_discriminator.train().requires_grad_(True) + self.resume_from_checkpoint(resume_from_checkpoint, remove_prefix_in_ckpt) + + self.use_gradient_checkpointing = use_gradient_checkpointing + self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload + self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else [] + self.fp8_models = fp8_models + self.embedded_guidance = embedded_guidance + self.task = task + self.dmd2_student_model_name = "dit" + self.dmd2_teacher_model_name = "dit_teacher" + self.dmd2_fake_score_model_name = "dit_fake_score" + self.dmd2_discriminator_model_name = "dmd2_discriminator" + self.dmd2_model_fn_student = model_fn_flux2_dmd2 + self.dmd2_model_fn_teacher = model_fn_flux2_dmd2 + self.dmd2_model_fn_fake_score = model_fn_flux2_dmd2 + self.dmd2_loss = DMD2Loss(config) + self._dmd2_student_param_names = {name for name, param in self.pipe.dit.named_parameters() if param.requires_grad} + self._dmd2_fake_score_param_names = {name for name, _ in self.pipe.dit_fake_score.named_parameters()} + self.task_to_loss = { + "dmd2": lambda pipe, inputs_shared, inputs_posi, inputs_nega, iteration=None: self.dmd2_loss( + self, + (inputs_shared, inputs_posi, inputs_nega), + 0 if iteration is None else iteration, + ), + } + + def get_pipeline_inputs(self, data): + inputs_posi = {"prompt": data["prompt"]} + inputs_nega = {"negative_prompt": ""} + inputs_shared = { + "input_image": data["image"], + "height": data["image"].size[1], + "width": data["image"].size[0], + "embedded_guidance": self.embedded_guidance, + "cfg_scale": self.dmd2_loss.config.teacher_cfg_scale, + "rand_device": self.pipe.device, + "use_gradient_checkpointing": self.use_gradient_checkpointing, + "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload, + } + inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared) + return inputs_shared, inputs_posi, inputs_nega + + def forward(self, data, inputs=None, iteration=None): + if inputs is None: inputs = self.get_pipeline_inputs(data) + inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype) + for unit in self.pipe.units: + inputs = self.pipe.unit_runner(unit, self.pipe, *inputs) + return self.task_to_loss[self.task](self.pipe, *inputs, iteration=iteration) + + def export_trainable_state_dict(self, state_dict, remove_prefix=None): + return export_dmd2_trainable_state_dict(self, state_dict, remove_prefix=remove_prefix) + +def flux2_dmd2_parser(): + parser = argparse.ArgumentParser(description="Flux.2 DMD2 training script.") + parser = add_general_config(parser) + parser = add_image_size_config(parser) + parser = add_dmd2_config(parser) + parser.set_defaults(task="dmd2") + parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.") + parser.add_argument("--embedded_guidance", type=float, default=1.0, help="Flux.2 embedded guidance value.") + parser.add_argument("--initialize_model_on_cpu", default=False, action="store_true", help="Whether to initialize models on CPU.") + return parser + + +if __name__ == "__main__": + parser = flux2_dmd2_parser() + args = parser.parse_args() + args.find_unused_parameters = True + + accelerator = accelerate.Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)], + ) + dataset = UnifiedDataset( + base_path=args.dataset_base_path, + metadata_path=args.dataset_metadata_path, + repeat=args.dataset_repeat, + data_file_keys=args.data_file_keys.split(","), + main_data_operator=UnifiedDataset.default_image_operator( + base_path=args.dataset_base_path, + max_pixels=args.max_pixels, + height=args.height, + width=args.width, + height_division_factor=16, + width_division_factor=16, + ) + ) + model = Flux2DMD2TrainingModule( + model_paths=args.model_paths, + model_id_with_origin_paths=args.model_id_with_origin_paths, + tokenizer_path=args.tokenizer_path, + trainable_models=args.trainable_models, + lora_base_model=args.lora_base_model, + lora_target_modules=args.lora_target_modules, + lora_rank=args.lora_rank, + lora_checkpoint=args.lora_checkpoint, + preset_lora_path=args.preset_lora_path, + preset_lora_model=args.preset_lora_model, + use_gradient_checkpointing=args.use_gradient_checkpointing, + use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, + extra_inputs=args.extra_inputs, + fp8_models=args.fp8_models, + offload_models=args.offload_models, + embedded_guidance=args.embedded_guidance, + resume_from_checkpoint=args.resume_from_checkpoint, + remove_prefix_in_ckpt=args.remove_prefix_in_ckpt, + dmd2_student_update_freq=args.dmd2_student_update_freq, + dmd2_student_sample_steps=args.dmd2_student_sample_steps, + dmd2_student_sample_type=args.dmd2_student_sample_type, + dmd2_student_schedule=args.dmd2_student_schedule, + dmd2_student_t_list=args.dmd2_student_t_list, + dmd2_matching_t_min=args.dmd2_matching_t_min, + dmd2_matching_t_max=args.dmd2_matching_t_max, + dmd2_matching_t_sampling=args.dmd2_matching_t_sampling, + dmd2_matching_t_mean=args.dmd2_matching_t_mean, + dmd2_matching_t_std=args.dmd2_matching_t_std, + dmd2_gan_loss_weight=args.dmd2_gan_loss_weight, + dmd2_gan_r1_reg_weight=args.dmd2_gan_r1_reg_weight, + dmd2_gan_r1_reg_alpha=args.dmd2_gan_r1_reg_alpha, + dmd2_fake_score_learning_rate=args.dmd2_fake_score_learning_rate, + dmd2_discriminator_learning_rate=args.dmd2_discriminator_learning_rate, + dmd2_feature_indices=args.dmd2_feature_indices, + dmd2_teacher_cfg_scale=args.dmd2_teacher_cfg_scale, + dmd2_student_grad_clip_norm=args.dmd2_student_grad_clip_norm, + task=args.task, + device="cpu" if (args.initialize_model_on_cpu or args.enable_model_cpu_offload) else accelerator.device, + ) + model_logger = ModelLogger( + args.output_path, + remove_prefix_in_ckpt=args.remove_prefix_in_ckpt, + enable_tensorboard_log=args.enable_tensorboard_log, + enable_swanlab_log=args.enable_swanlab_log, + swanlab_project=args.swanlab_project, + enable_wandb_log=args.enable_wandb_log, + wandb_project=args.wandb_project, + ) + launch_dmd2_training_task(accelerator, dataset, model, model_logger, args=args) diff --git a/examples/flux2/model_training/train.py b/examples/flux2/model_training/train.py index 69baba46d..bd4938e65 100644 --- a/examples/flux2/model_training/train.py +++ b/examples/flux2/model_training/train.py @@ -18,11 +18,9 @@ def __init__( extra_inputs=None, fp8_models=None, offload_models=None, - embedded_guidance=1.0, template_model_id_or_path=None, resume_from_checkpoint=None, remove_prefix_in_ckpt=None, enable_lora_hot_loading=False, - dmd2_config=None, device="cpu", task="sft", ): @@ -49,10 +47,7 @@ def __init__( self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else [] self.fp8_models = fp8_models - self.embedded_guidance = embedded_guidance self.task = task - self.dmd2_config = None - self.task_to_input_processor = {} self.task_to_loss = { "sft:data_process": lambda pipe, *args: args, "direct_distill:data_process": lambda pipe, *args: args, @@ -61,9 +56,6 @@ def __init__( "direct_distill": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi), "direct_distill:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi), } - self.task_to_state_dict_exporter = {} - if task == "dmd2": - setup_dmd2_training(self, self.pipe, dmd2_config) def get_pipeline_inputs(self, data): inputs_posi = {"prompt": data["prompt"]} @@ -76,38 +68,29 @@ def get_pipeline_inputs(self, data): "width": data["image"].size[0], # Please do not modify the following parameters # unless you clearly know what this will cause. - "embedded_guidance": self.embedded_guidance, + "embedded_guidance": 1.0, "cfg_scale": 1, "rand_device": self.pipe.device, "use_gradient_checkpointing": self.use_gradient_checkpointing, "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload, } - if self.task in self.task_to_input_processor: - inputs_shared = self.task_to_input_processor[self.task](inputs_shared) inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared) return inputs_shared, inputs_posi, inputs_nega - def forward(self, data, inputs=None, iteration=None): + def forward(self, data, inputs=None): if inputs is None: inputs = self.get_pipeline_inputs(data) inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype) for unit in self.pipe.units: inputs = self.pipe.unit_runner(unit, self.pipe, *inputs) - loss = self.task_to_loss[self.task](self.pipe, *inputs, iteration=iteration) + loss = self.task_to_loss[self.task](self.pipe, *inputs) return loss - def export_trainable_state_dict(self, state_dict, remove_prefix=None): - if self.task in self.task_to_state_dict_exporter: - return self.task_to_state_dict_exporter[self.task](state_dict, remove_prefix=remove_prefix) - return super().export_trainable_state_dict(state_dict, remove_prefix=remove_prefix) - def flux2_parser(): parser = argparse.ArgumentParser(description="Simple example of a training script.") parser = add_general_config(parser) parser = add_image_size_config(parser) - parser = add_dmd2_config(parser) parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.") - parser.add_argument("--embedded_guidance", type=float, default=1.0, help="Flux.2 embedded guidance value.") parser.add_argument("--initialize_model_on_cpu", default=False, action="store_true", help="Whether to initialize models on CPU.") return parser @@ -115,8 +98,6 @@ def flux2_parser(): if __name__ == "__main__": parser = flux2_parser() args = parser.parse_args() - if args.task == "dmd2": - args.find_unused_parameters = True accelerator = accelerate.Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, @@ -152,12 +133,10 @@ def flux2_parser(): extra_inputs=args.extra_inputs, fp8_models=args.fp8_models, offload_models=args.offload_models, - embedded_guidance=args.embedded_guidance, template_model_id_or_path=args.template_model_id_or_path, resume_from_checkpoint=args.resume_from_checkpoint, remove_prefix_in_ckpt=args.remove_prefix_in_ckpt, enable_lora_hot_loading=args.enable_lora_hot_loading, - dmd2_config=DMD2Config.from_args(args), task=args.task, device="cpu" if (args.initialize_model_on_cpu or args.enable_model_cpu_offload) else accelerator.device, ) @@ -177,6 +156,5 @@ def flux2_parser(): "sft:train": launch_training_task, "direct_distill": launch_training_task, "direct_distill:train": launch_training_task, - "dmd2": launch_dmd2_training_task, } - launcher_map[args.task](accelerator, dataset, model, model_logger, args=args) + launcher_map[args.task](accelerator, dataset, model, model_logger, args=args) \ No newline at end of file diff --git a/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B-DMD2.py b/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B-DMD2.py index f269d4e2e..067f01901 100644 --- a/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B-DMD2.py +++ b/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B-DMD2.py @@ -12,9 +12,9 @@ ], tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"), ) -state_dict = load_state_dict("./models/train/FLUX.2-klein-base-4B_dmd2-4steps/step-2000.safetensors", torch_dtype=torch.bfloat16) +state_dict = load_state_dict("./models/train/FLUX.2-klein-base-4B_dmd2/step-3000.safetensors", torch_dtype=torch.bfloat16) pipe.dit.load_state_dict(state_dict) prompt = "a dog" image = pipe(prompt=prompt, seed=0, num_inference_steps=4, height=512, width=512) -image.save("image_FLUX.2-klein-base-4B-DMD2-4steps.jpg") \ No newline at end of file +image.save("FLUX.2-klein-base-4B_dmd2.jpg") \ No newline at end of file