diff --git a/diffsynth/diffusion/__init__.py b/diffsynth/diffusion/__init__.py index 7823637e2..d285482de 100644 --- a/diffsynth/diffusion/__init__.py +++ b/diffsynth/diffusion/__init__.py @@ -4,3 +4,4 @@ from .runner import launch_training_task, launch_data_process_task from .parsers import * from .loss import * +from .dmd2 import * diff --git a/diffsynth/diffusion/dmd2.py b/diffsynth/diffusion/dmd2.py new file mode 100644 index 000000000..091d1ef17 --- /dev/null +++ b/diffsynth/diffusion/dmd2.py @@ -0,0 +1,601 @@ +from dataclasses import dataclass +from typing import Optional +import torch +import torch.nn.functional as F +from accelerate import Accelerator +from tqdm import tqdm +from .runner import get_optimizer_class, initialize_deepspeed_gradient_checkpointing +from diffsynth.core import OffloadTrainingManager + + +@dataclass +class DMD2Config: + student_update_freq: int = 5 + student_sample_steps: int = 4 + student_sample_type: str = "sde" + student_schedule: str = "uniform" + student_t_list: Optional[list[float]] = None + matching_t_min: float = 0.001 + matching_t_max: float = 0.999 + matching_t_sampling: str = "uniform" + matching_t_mean: float = 0.0 + matching_t_std: float = 1.0 + gan_loss_weight: float = 0.03 + gan_r1_reg_weight: float = 0.0 + gan_r1_reg_alpha: float = 0.1 + fake_score_learning_rate: Optional[float] = None + discriminator_learning_rate: Optional[float] = None + feature_indices: Optional[list[int]] = None + teacher_cfg_scale: float = 1.0 + student_grad_clip_norm: Optional[float] = 10.0 + + +def _get_dmd2_pipe_model(module, model_name): + if model_name is None: + return None + return getattr(module.pipe, model_name) + + +def get_dmd2_student_model(module): + return _get_dmd2_pipe_model(module, module.dmd2_student_model_name) + + +def get_dmd2_teacher_model(module): + return _get_dmd2_pipe_model(module, module.dmd2_teacher_model_name) + + +def get_dmd2_fake_score_model(module): + return _get_dmd2_pipe_model(module, 
def _dmd2_model_state_names(module, model_name, param_names=None):
    """Return the fully-qualified state-dict keys of one pipeline sub-model.

    Args:
        module: Training module that exposes the pipeline as ``module.pipe``.
        model_name: Attribute name of the sub-model on ``module.pipe``.
            ``None`` means the role is not configured.
        param_names: Optional iterable of local names to qualify. When
            ``None``, every key of the sub-model's ``state_dict()`` is used.

    Returns:
        A set of keys of the form ``"pipe.<model_name>.<name>"``. The set is
        empty when the role is not configured or the pipeline has no such
        attribute (or it is ``None``).
    """
    if model_name is None:
        return set()
    model = getattr(module.pipe, model_name, None)
    if model is None:
        return set()
    # Iterating a state dict yields its keys; it is only built when needed.
    local_names = model.state_dict() if param_names is None else param_names
    prefix = f"pipe.{model_name}."
    return {prefix + name for name in local_names}


def export_dmd2_trainable_state_dict(module, state_dict, remove_prefix=None):
    """Filter a full training state dict down to the DMD2 weights to save.

    The student entries listed in ``module._dmd2_student_param_names`` are
    always kept. When ``remove_prefix`` is ``None`` the teacher, fake-score and
    discriminator entries are kept as well; when a prefix is given only the
    student entries are exported and the prefix is stripped from their names.

    Args:
        module: Training module carrying ``pipe`` and the ``dmd2_*_model_name``
            attributes.
        state_dict: Mapping from fully-qualified parameter name to value.
        remove_prefix: Optional prefix to strip from exported names.

    Returns:
        A new dict holding only the selected entries, in the iteration order
        of ``state_dict``.
    """
    selected = _dmd2_model_state_names(
        module, module.dmd2_student_model_name, module._dmd2_student_param_names
    )
    if remove_prefix is None:
        auxiliary_roles = (
            module.dmd2_teacher_model_name,
            module.dmd2_fake_score_model_name,
            module.dmd2_discriminator_model_name,
        )
        for model_name in auxiliary_roles:
            selected |= _dmd2_model_state_names(module, model_name)

    exported = {}
    for name, param in state_dict.items():
        if name not in selected:
            continue
        if remove_prefix is not None:
            # Names that do not start with the prefix are kept unchanged.
            name = name.removeprefix(remove_prefix)
        exported[name] = param
    return exported
def _expand_like(value, target, dtype=None):
    """Prepare ``value`` for broadcasting against ``target``.

    ``value`` is converted to a tensor on ``target``'s device, cast to
    ``dtype`` (``target.dtype`` when omitted) and given trailing singleton
    dimensions until it has at least as many dimensions as ``target``.
    """
    if dtype is None:
        dtype = target.dtype
    if not isinstance(value, torch.Tensor):
        value = torch.tensor(value, device=target.device, dtype=dtype)
    value = value.to(device=target.device, dtype=dtype)
    missing_dims = target.ndim - value.ndim
    if missing_dims > 0:
        # Indexing with ``None`` appends size-1 dims as a view, which also
        # works for expanded or otherwise non-contiguous inputs.
        value = value[(..., *([None] * missing_dims))]
    return value


def flow_to_x0(latents, flow, sigma):
    """Return ``latents - sigma * flow``, computed in float64.

    The result is cast back to the dtype ``latents`` had on entry.
    NOTE(review): presumably the clean-sample estimate for a model that
    predicts ``eps - x0`` under the interpolation used by
    ``_forward_process`` -- confirm against the model functions.
    """
    original_dtype = latents.dtype
    latents = latents.to(torch.float64)
    flow = flow.to(torch.float64)
    sigma = _expand_like(sigma, latents)
    return (latents - sigma * flow).to(original_dtype)


def _forward_process(x0, eps, sigma):
    """Return ``(1 - sigma) * x0 + sigma * eps`` in float64, cast to ``x0``'s dtype."""
    original_dtype = x0.dtype
    x0 = x0.to(torch.float64)
    eps = eps.to(torch.float64)
    sigma = _expand_like(sigma, x0)
    return ((1 - sigma) * x0 + sigma * eps).to(original_dtype)


def _scale_noise(noise, sigma):
    """Return ``noise * sigma`` computed in float64, cast to ``noise``'s dtype."""
    original_dtype = noise.dtype
    noise = noise.to(torch.float64)
    sigma = _expand_like(sigma, noise)
    return (noise * sigma).to(original_dtype)


def make_dmd2_student_schedule(
    pipe,
    num_steps,
    device,
    student_schedule="uniform",
    student_t_list=None,
    matching_t_max=0.999,
):
    """Build the student's sigma schedule and the matching timesteps.

    Args:
        pipe: Pipeline exposing ``scheduler.num_train_timesteps``.
        num_steps: Number of student sampling steps; the schedule has
            ``num_steps + 1`` entries.
        device: Device on which the returned tensors are created.
        student_schedule: Name of the built-in schedule; only ``"uniform"``
            is supported.
        student_t_list: Optional explicit sigma list. It takes precedence over
            ``student_schedule`` and must end with ``0``.
        matching_t_max: Upper sigma bound; the uniform schedule starts at
            ``min(matching_t_max, 0.999)``.

    Returns:
        ``(sigmas, timesteps)`` as float64 tensors of length ``num_steps + 1``,
        where ``timesteps = sigmas * pipe.scheduler.num_train_timesteps``.

    Raises:
        ValueError: If ``num_steps`` is smaller than 1, the explicit list has
            the wrong length or a non-zero last entry, or the schedule name is
            unknown.
    """
    if num_steps < 1:
        raise ValueError("`student_sample_steps` must be at least 1.")
    time_dtype = torch.float64
    if student_t_list is not None:
        # Check the length before indexing so an empty list is reported as a
        # ValueError rather than an IndexError from ``sigmas[-1]``.
        if len(student_t_list) != num_steps + 1:
            raise ValueError("The student sigma schedule length must equal `student_sample_steps + 1`.")
        sigmas = torch.tensor(student_t_list, device=device, dtype=time_dtype)
        if sigmas[-1].item() != 0:
            raise ValueError("`dmd2_student_t_list` must include a final 0.")
    elif student_schedule == "uniform":
        sigma_start = min(float(matching_t_max), 0.999)
        sigmas = torch.linspace(sigma_start, 0.0, num_steps + 1, device=device, dtype=time_dtype)
    else:
        raise ValueError(f"Unsupported DMD2 student schedule: {student_schedule}")
    timesteps = sigmas * pipe.scheduler.num_train_timesteps
    return sigmas, timesteps


def _variational_score_distillation_loss(gen_data, teacher_x0, fake_score_x0):
    """Distribution-matching loss expressed as an MSE to a fixed pseudo-target.

    The per-sample weight is ``1 / (mean|gen_data - teacher_x0| + 1e-6)``,
    averaged over all non-batch dimensions. The pseudo-target is
    ``gen_data - (fake_score_x0 - teacher_x0) * weight``.

    Returns:
        A scalar tensor ``0.5 * mse(gen_data, pseudo_target)`` in float32.
    """
    dims = tuple(range(1, teacher_x0.ndim))
    with torch.no_grad():
        weight = 1 / ((gen_data.float() - teacher_x0.float()).abs().mean(dim=dims, keepdim=True) + 1e-6)
        weight = weight.to(dtype=gen_data.dtype)
        # The pseudo-target must be a constant. If it carried gradient,
        # ``gen_data - pseudo_target`` would not depend on ``gen_data`` and the
        # student would receive a zero gradient.
        pseudo_target = gen_data - (fake_score_x0 - teacher_x0) * weight
    loss = 0.5 * F.mse_loss(gen_data.float(), pseudo_target.float(), reduction="mean")
    return loss


def _gan_loss_generator(fake_logits):
    """Non-saturating generator loss: mean ``softplus(-fake_logits)``."""
    assert fake_logits.ndim == 2, f"fake_logits has shape {fake_logits.shape}"
    gan_loss = F.softplus(-fake_logits).mean()
    return gan_loss


def _gan_loss_discriminator(real_logits, fake_logits):
    """Discriminator loss: mean ``softplus(fake)`` plus mean ``softplus(-real)``."""
    assert fake_logits.ndim == 2, f"fake_logits has shape {fake_logits.shape}"
    assert real_logits.ndim == 2, f"real_logits has shape {real_logits.shape}"
    gan_loss = F.softplus(fake_logits).mean() + F.softplus(-real_logits).mean()
    return gan_loss
def _sample_matching_timestep(self, pipe, device, dtype, batch_size=1, inputs_shared=None, real_data=None):
    """Draw one matching sigma per batch element and its scheduler timestep.

    Sigmas are sampled in float64 inside
    ``[config.matching_t_min, config.matching_t_max]``, either uniformly or
    from a logit-normal distribution parameterised by
    ``config.matching_t_mean`` and ``config.matching_t_std``. ``dtype``,
    ``inputs_shared`` and ``real_data`` are accepted but not used here.

    Returns:
        ``(timestep, sigma)``, both of shape ``(batch_size,)``, where
        ``timestep = sigma * pipe.scheduler.num_train_timesteps``.

    Raises:
        ValueError: If the configured range is inverted or the sampling mode
            is not ``"uniform"`` or ``"logitnormal"``.
    """
    cfg = self.config
    low, high = cfg.matching_t_min, cfg.matching_t_max
    if low > high:
        raise ValueError("`dmd2_matching_t_min` must be <= `dmd2_matching_t_max`.")
    mode = cfg.matching_t_sampling
    if mode not in ("uniform", "logitnormal"):
        raise ValueError(f"Unsupported DMD2 matching timestep sampling: {self.config.matching_t_sampling}")

    if mode == "uniform":
        unit = torch.rand(batch_size, device=device, dtype=torch.float64)
        sigma = unit * (high - low) + low
    else:
        gaussian = torch.randn(batch_size, device=device, dtype=torch.float64)
        unit = torch.sigmoid(gaussian * cfg.matching_t_std + cfg.matching_t_mean)
        # Only the logit-normal branch clamps the rescaled sigma.
        sigma = (unit * (high - low) + low).clamp(low, high)
    return sigma * pipe.scheduler.num_train_timesteps, sigma
inputs_posi): + pipe = module.pipe + device, dtype = real_data.device, real_data.dtype + batch_size = real_data.shape[0] + student_sigmas, student_timesteps = make_dmd2_student_schedule( + pipe, + self.config.student_sample_steps, + device, + student_schedule=self.config.student_schedule, + student_t_list=self.config.student_t_list, + matching_t_max=self.config.matching_t_max, + ) + if len(student_sigmas) != self.config.student_sample_steps + 1: + raise ValueError("The student sigma schedule length must equal `student_sample_steps + 1`.") + + if self.config.student_sample_steps == 1: + timestep = student_timesteps[0:1].expand(batch_size) + sigma = student_sigmas[0:1].expand(batch_size) + input_student = _scale_noise(torch.randn_like(real_data), sigma) + else: + step_id = torch.randint(0, self.config.student_sample_steps, (batch_size,), device=device) + sigma = student_sigmas[step_id] + timestep = student_timesteps[step_id] + eps_student = torch.randn_like(real_data) + input_student = _forward_process(real_data, eps_student, sigma) + gen_data = self._model_forward_x0( + module, + get_dmd2_student_model(module), + module.dmd2_model_fn_student, + input_student, + timestep, + sigma, + inputs_shared, + inputs_posi, + ) + return gen_data, input_student + + def _compute_real_feat(self, module, real_data, timestep, sigma, eps, inputs_shared, inputs_posi): + real_timestep, real_sigma, real_eps = timestep, sigma, eps + perturbed_real = _forward_process(real_data, real_eps, real_sigma) + real_feat = self._model_forward_features( + module, + get_dmd2_teacher_model(module), + module.dmd2_model_fn_teacher, + perturbed_real, + real_timestep, + inputs_shared, + inputs_posi, + ) + return real_feat, real_timestep, real_sigma + + def _student_update_step(self, module, real_data, inputs_shared, inputs_posi, inputs_nega): + gen_data, input_student = self._generate_student_data(module, real_data, inputs_shared, inputs_posi) + timestep, sigma = self._sample_matching_timestep( + 
module.pipe, + real_data.device, + real_data.dtype, + real_data.shape[0], + inputs_shared=inputs_shared, + real_data=real_data, + ) + eps = torch.randn_like(real_data) + perturbed_data = _forward_process(gen_data, eps, sigma) + + with torch.no_grad(): + fake_score_x0 = self._model_forward_x0( + module, + get_dmd2_fake_score_model(module), + module.dmd2_model_fn_fake_score, + perturbed_data, + timestep, + sigma, + inputs_shared, + inputs_posi, + ) + + with torch.no_grad(): + teacher_x0 = self._teacher_forward_x0( + module, perturbed_data, timestep, sigma, inputs_shared, inputs_posi, inputs_nega + ) + if self.config.gan_loss_weight > 0: + fake_feat = self._model_forward_features( + module, + get_dmd2_teacher_model(module), + module.dmd2_model_fn_teacher, + perturbed_data, + timestep, + inputs_shared, + inputs_posi, + ) + fake_logits_gen = get_dmd2_discriminator(module)(module.pipe, fake_feat) + gan_loss_gen = _gan_loss_generator(fake_logits_gen) + else: + gan_loss_gen = torch.zeros((), device=real_data.device, dtype=torch.float32) + + vsd_loss = _variational_score_distillation_loss(gen_data, teacher_x0.detach(), fake_score_x0) + gan_loss_weighted = self.config.gan_loss_weight * gan_loss_gen + loss = vsd_loss + gan_loss_weighted + + return {"total_loss": loss} + + def _fake_score_discriminator_update_step(self, module, real_data, inputs_shared, inputs_posi, inputs_nega): + with torch.no_grad(): + gen_data, _ = self._generate_student_data(module, real_data, inputs_shared, inputs_posi) + timestep, sigma = self._sample_matching_timestep( + module.pipe, + real_data.device, + real_data.dtype, + real_data.shape[0], + inputs_shared=inputs_shared, + real_data=real_data, + ) + eps = torch.randn_like(real_data) + x_t_sg = _forward_process(gen_data, eps, sigma) + + fake_score_x0 = self._model_forward_x0( + module, + get_dmd2_fake_score_model(module), + module.dmd2_model_fn_fake_score, + x_t_sg, + timestep, + sigma, + inputs_shared, + inputs_posi, + ) + fake_score_loss = 
F.mse_loss(fake_score_x0.float(), gen_data.float(), reduction="mean") + + gan_loss_disc = torch.zeros_like(fake_score_loss) + gan_loss_ar1 = torch.zeros_like(fake_score_loss) + gan_loss_logit_reg = torch.zeros_like(fake_score_loss) + if self.config.gan_loss_weight > 0: + with torch.no_grad(): + fake_feat = self._model_forward_features( + module, + get_dmd2_teacher_model(module), + module.dmd2_model_fn_teacher, + x_t_sg, + timestep, + inputs_shared, + inputs_posi, + ) + real_feat, real_timestep, real_sigma = self._compute_real_feat( + module, real_data, timestep, sigma, eps, inputs_shared, inputs_posi + ) + discriminator = get_dmd2_discriminator(module) + real_logits = discriminator(module.pipe, real_feat) + fake_logits = discriminator(module.pipe, fake_feat) + gan_loss_disc = _gan_loss_discriminator(real_logits, fake_logits) + if self.config.gan_r1_reg_weight > 0: + perturbed_real_alpha = real_data + self.config.gan_r1_reg_alpha * torch.randn_like(real_data) + with torch.no_grad(): + real_feat_alpha = self._model_forward_features( + module, + get_dmd2_teacher_model(module), + module.dmd2_model_fn_teacher, + perturbed_real_alpha, + real_timestep, + inputs_shared, + inputs_posi, + ) + real_logits_alpha = discriminator(module.pipe, real_feat_alpha) + gan_loss_ar1 = F.mse_loss(real_logits, real_logits_alpha, reduction="mean") + + loss = ( + fake_score_loss + + gan_loss_disc + + self.config.gan_r1_reg_weight * gan_loss_ar1 + ) + return {"total_loss": loss} + + def __call__(self, module, inputs, iteration): + inputs_shared, inputs_posi, inputs_nega = inputs + real_data = inputs_shared.get("input_latents") + if real_data is None: + raise ValueError("DMD2 requires image latents from the dataset. 
def _dmd2_current_optimizers(config, optimizers, iteration):
    """Select the optimizers and LR schedulers that step at ``iteration``.

    Iterations divisible by ``config.student_update_freq`` belong to the
    student. Every other iteration belongs to the fake-score model, plus the
    discriminator when ``optimizers`` has a ``"discriminator"`` entry.

    Returns:
        ``(optimizers, schedulers)`` as two parallel lists.
    """
    if iteration % config.student_update_freq == 0:
        active = [("student", "student_scheduler")]
    else:
        active = [("fake_score", "fake_score_scheduler")]
        if "discriminator" in optimizers:
            active.append(("discriminator", "discriminator_scheduler"))
    current_optimizers = [optimizers[optimizer_key] for optimizer_key, _ in active]
    current_schedulers = [optimizers[scheduler_key] for _, scheduler_key in active]
    return current_optimizers, current_schedulers
config.fake_score_learning_rate or learning_rate + discriminator_lr = config.discriminator_learning_rate or learning_rate + + student_model = get_dmd2_student_model(model) + fake_score_model = get_dmd2_fake_score_model(model) + discriminator = get_dmd2_discriminator(model) + has_discriminator = discriminator is not None + + student_optimizer = optimizer_class(_trainable_params(student_model), lr=learning_rate, weight_decay=weight_decay) + fake_score_optimizer = optimizer_class(fake_score_model.parameters(), lr=fake_score_lr, weight_decay=weight_decay) + student_scheduler = torch.optim.lr_scheduler.ConstantLR(student_optimizer, factor=1.0, total_iters=1) + fake_score_scheduler = torch.optim.lr_scheduler.ConstantLR(fake_score_optimizer, factor=1.0, total_iters=1) + + optimizers = { + "student": student_optimizer, + "fake_score": fake_score_optimizer, + "student_scheduler": student_scheduler, + "fake_score_scheduler": fake_score_scheduler, + } + prepare_items = [model, student_optimizer, fake_score_optimizer, student_scheduler, fake_score_scheduler] + if has_discriminator: + discriminator_optimizer = optimizer_class(discriminator.parameters(), lr=discriminator_lr, weight_decay=weight_decay) + discriminator_scheduler = torch.optim.lr_scheduler.ConstantLR(discriminator_optimizer, factor=1.0, total_iters=1) + optimizers["discriminator"] = discriminator_optimizer + optimizers["discriminator_scheduler"] = discriminator_scheduler + prepare_items.extend([discriminator_optimizer, discriminator_scheduler]) + + dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers) + prepare_items.append(dataloader) + + if enable_model_cpu_offload: + prepared = accelerator.prepare(*prepare_items[1:]) + model.pipe.device = accelerator.device + offload_manager = OffloadTrainingManager( + model, + accelerator.device, + enable_optimizer_cpu_offload, + cpu_offload_split_threshold, + ) + prepared_tail = list(prepared) + else: + 
model.to(device=accelerator.device) + prepared = accelerator.prepare(*prepare_items) + model = prepared[0] + prepared_tail = list(prepared[1:]) + + optimizers["student"] = prepared_tail.pop(0) + optimizers["fake_score"] = prepared_tail.pop(0) + optimizers["student_scheduler"] = prepared_tail.pop(0) + optimizers["fake_score_scheduler"] = prepared_tail.pop(0) + if has_discriminator: + optimizers["discriminator"] = prepared_tail.pop(0) + optimizers["discriminator_scheduler"] = prepared_tail.pop(0) + dataloader = prepared_tail.pop(0) + + initialize_deepspeed_gradient_checkpointing(accelerator) + iteration = 0 + for epoch_id in range(num_epochs): + for data in tqdm(dataloader): + with accelerator.accumulate(model): + if dataset.load_from_cache: + loss_map = model({}, inputs=data, iteration=iteration) + else: + loss_map = model(data, iteration=iteration) + loss = loss_map["total_loss"] + current_optimizers, current_schedulers = _dmd2_current_optimizers(config, optimizers, iteration) + accelerator.backward(loss) + if enable_model_cpu_offload: + offload_manager.after_backward() + if iteration % config.student_update_freq == 0 and config.student_grad_clip_norm is not None and config.student_grad_clip_norm > 0: + accelerator.clip_grad_norm_(_trainable_params(get_dmd2_student_model(accelerator.unwrap_model(model))), config.student_grad_clip_norm) + for optimizer in current_optimizers: + optimizer.step() + for scheduler in current_schedulers: + scheduler.step() + for optimizer in current_optimizers: + optimizer.zero_grad(set_to_none=True) + model_logger.on_step_end(accelerator, model, save_steps, loss=loss) + iteration += 1 + if save_steps is None: + model_logger.on_epoch_end(accelerator, model, epoch_id) + + model_logger.on_training_end(accelerator, model, save_steps) \ No newline at end of file diff --git a/diffsynth/diffusion/parsers.py b/diffsynth/diffusion/parsers.py index c2c0b6ef7..cc190516b 100644 --- a/diffsynth/diffusion/parsers.py +++ 
def add_dmd2_config(parser: argparse.ArgumentParser):
    """Register every ``--dmd2_*`` command-line option on ``parser``.

    List-valued options (``--dmd2_student_t_list``, ``--dmd2_feature_indices``)
    are registered as plain strings; this function does not parse them.

    Returns:
        The same ``parser`` instance, so calls can be chained.
    """
    option_table = (
        ("--dmd2_student_update_freq", dict(type=int, default=5, help="Update student once every N DMD2 iterations.")),
        ("--dmd2_student_sample_steps", dict(type=int, default=4, help="Number of distilled student sampling steps.")),
        ("--dmd2_student_sample_type", dict(type=str, default="sde", choices=["sde", "ode"], help="Student sampling type.")),
        ("--dmd2_student_schedule", dict(type=str, default="uniform", choices=["uniform"], help="Student sigma schedule.")),
        ("--dmd2_student_t_list", dict(type=str, default=None, help="Optional student sigma schedule, including the final 0.")),
        ("--dmd2_matching_t_min", dict(type=float, default=0.001, help="Minimum matching sigma sampled for DMD2.")),
        ("--dmd2_matching_t_max", dict(type=float, default=0.999, help="Maximum matching sigma sampled for DMD2.")),
        ("--dmd2_matching_t_sampling", dict(type=str, default="uniform", choices=["uniform", "logitnormal"], help="Sample matching sigma.")),
        ("--dmd2_matching_t_mean", dict(type=float, default=0.0, help="Mean for logitnormal matching timestep sampling.")),
        ("--dmd2_matching_t_std", dict(type=float, default=1.0, help="Std for logitnormal matching timestep sampling.")),
        ("--dmd2_gan_loss_weight", dict(type=float, default=0.03, help="Generator GAN loss weight.")),
        ("--dmd2_gan_r1_reg_weight", dict(type=float, default=0.0, help="Approximate R1 regularization weight for the discriminator.")),
        ("--dmd2_gan_r1_reg_alpha", dict(type=float, default=0.1, help="Noise scale for approximate R1 regularization.")),
        ("--dmd2_fake_score_learning_rate", dict(type=float, default=None, help="Learning rate for the fake score model.")),
        ("--dmd2_discriminator_learning_rate", dict(type=float, default=None, help="Learning rate for the discriminator.")),
        ("--dmd2_feature_indices", dict(type=str, default=None, help="Model feature indices used by the DMD2 discriminator.")),
        ("--dmd2_teacher_cfg_scale", dict(type=float, default=1.0, help="CFG scale.")),
        ("--dmd2_student_grad_clip_norm", dict(type=float, default=10.0, help="Clip student gradients to this norm.")),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser
def _parse_delimited(value, cast):
    """Split a comma-separated string and convert each non-blank item with ``cast``.

    Returns ``None`` when ``value`` is ``None`` or contains only whitespace.
    Blank items (for example from a trailing comma or ``", "``) are skipped.
    """
    if value is None or value.strip() == "":
        return None
    items = (item.strip() for item in value.split(","))
    return [cast(item) for item in items if item]


def _parse_int_list(value):
    """Parse a string such as ``"4, 12,20"`` into a list of ints.

    Returns ``None`` for ``None`` or a blank string.

    Raises:
        ValueError: If a non-blank item is not a valid integer literal.
    """
    return _parse_delimited(value, int)


def _parse_float_list(value):
    """Parse a string such as ``"1.0, 0.5, 0"`` into a list of floats.

    Returns ``None`` for ``None`` or a blank string.

    Raises:
        ValueError: If a non-blank item is not a valid float literal.
    """
    return _parse_delimited(value, float)
def _get_optimal_groups(num_channels):
    """Choose a GroupNorm group count that evenly divides ``num_channels``.

    The search starts at ``num_channels // 4`` (at least 1) for up to 32
    channels and at 32 otherwise, then counts down to the nearest divisor.

    Raises:
        ValueError: If ``num_channels`` is not positive.
    """
    if num_channels < 1:
        raise ValueError(f"num_channels must be a positive integer, got {num_channels}.")
    groups = max(1, num_channels // 4) if num_channels <= 32 else 32
    # Always terminates: 1 divides every positive integer.
    while num_channels % groups != 0:
        groups -= 1
    return groups


class FluxDMD2Discriminator(torch.nn.Module):
    """One small convolutional head per selected feature map.

    Each head maps a ``(batch, inner_dim, H, W)`` feature map to a single
    logit per sample: strided conv, GroupNorm, LeakyReLU, 1x1 conv, global
    average pool. The logits of all heads are concatenated along dim 1.
    """

    def __init__(self, feature_indices=None, num_blocks=40, inner_dim=3072):
        """Build the heads.

        Args:
            feature_indices: Block indices whose features are scored. Defaults
                to the middle block. Duplicates and indices outside
                ``[0, num_blocks)`` are dropped.
            num_blocks: Number of blocks that indices may refer to.
            inner_dim: Channel count of every incoming feature map.

        Raises:
            ValueError: If no valid feature index remains.
        """
        super().__init__()
        if feature_indices is None:
            feature_indices = [int(num_blocks // 2)]
        self.feature_indices = sorted({int(i) for i in feature_indices if 0 <= int(i) < num_blocks})
        if len(self.feature_indices) == 0:
            raise ValueError("DMD2 discriminator requires at least one valid feature index.")
        self.num_features = len(self.feature_indices)
        self.inner_dim = inner_dim

        hidden_channels = inner_dim // 2
        self.heads = torch.nn.ModuleList([
            torch.nn.Sequential(
                torch.nn.Conv2d(inner_dim, hidden_channels, kernel_size=4, stride=2, padding=1),
                torch.nn.GroupNorm(_get_optimal_groups(hidden_channels), hidden_channels),
                torch.nn.LeakyReLU(0.2),
                torch.nn.Conv2d(hidden_channels, 1, kernel_size=1, stride=1, padding=0),
                torch.nn.AdaptiveAvgPool2d((1, 1)),
                torch.nn.Flatten(),
            )
            for _ in self.feature_indices
        ])

    def forward(self, pipe, feats):
        """Score the feature maps and return logits of shape ``(batch, num_features)``.

        Args:
            pipe: Unused; kept so the call signature matches the trainer's.
            feats: List or tuple of ``num_features`` tensors, paired with the
                heads in order.

        Raises:
            ValueError: If ``feats`` is not a list/tuple of the expected length.
        """
        is_sequence = isinstance(feats, (list, tuple))
        if not is_sequence or len(feats) != self.num_features:
            length = len(feats) if is_sequence else "N/A"
            raise ValueError(f"Expected list of {self.num_features} feature tensors, got {type(feats)} with length {length}.")
        logits = []
        for head, feat in zip(self.heads, feats):
            # Match each head's own device/dtype before applying it.
            param = next(head.parameters())
            logits.append(head(feat.to(device=param.device, dtype=param.dtype)))
        return torch.cat(logits, dim=1)


def _normalize_embedded_guidance(embedded_guidance, latents):
    """Return guidance as a ``(batch,)`` tensor matching ``latents``, or ``None``."""
    if embedded_guidance is None:
        return None
    batch_size = latents.shape[0]
    if isinstance(embedded_guidance, torch.Tensor):
        guidance = embedded_guidance.to(device=latents.device, dtype=latents.dtype).flatten()
        if guidance.numel() == 1:
            return guidance.expand(batch_size)
        if guidance.numel() != batch_size:
            raise ValueError("`embedded_guidance` must be a scalar or match the latent batch size.")
        return guidance
    return torch.full((batch_size,), float(embedded_guidance), device=latents.device, dtype=latents.dtype)


def _model_fn_flux2_base(
    dit,
    latents=None,
    timestep=None,
    embedded_guidance=None,
    prompt_embeds=None,
    text_ids=None,
    image_ids=None,
    edit_latents=None,
    edit_image_ids=None,
    kv_cache=None,
    extra_text_embedding=None,
    use_gradient_checkpointing=False,
    use_gradient_checkpointing_offload=False,
    **kwargs,
):
    """Run ``dit`` once and return its prediction for the image tokens only.

    Edit latents (if any) are appended to the image tokens before the call and
    sliced off the output afterwards. ``extra_text_embedding`` (if any) is
    appended to the prompt with position ids continuing after the prompt.
    Unrecognised keyword arguments are ignored.

    Raises:
        ValueError: If a tensor ``embedded_guidance`` is neither a scalar nor
            of the latent batch size.
    """
    image_seq_len = latents.shape[1]
    if edit_latents is not None:
        latents = torch.concat([latents, edit_latents], dim=1)
        image_ids = torch.concat([image_ids, edit_image_ids], dim=1)

    embedded_guidance = _normalize_embedded_guidance(embedded_guidance, latents)

    if extra_text_embedding is not None:
        num_prompt_tokens = prompt_embeds.shape[1]
        num_extra_tokens = extra_text_embedding.shape[1]
        # Follow ``text_ids``' own leading dimension so batches larger than one
        # can be concatenated. NOTE(review): assumes ``text_ids`` is
        # (batch, tokens, 4), as the dim=1 concatenation below implies.
        extra_text_ids = torch.zeros(
            (text_ids.shape[0], num_extra_tokens, 4),
            dtype=text_ids.dtype,
            device=text_ids.device,
        )
        extra_text_ids[:, :, -1] = torch.arange(
            num_prompt_tokens,
            num_prompt_tokens + num_extra_tokens,
            device=text_ids.device,
        )
        prompt_embeds = torch.concat([prompt_embeds, extra_text_embedding], dim=1)
        text_ids = torch.concat([text_ids, extra_text_ids], dim=1)

    model_output = dit(
        hidden_states=latents,
        # NOTE(review): presumably rescales scheduler timesteps to [0, 1];
        # confirm against the DiT's expected input range.
        timestep=timestep / 1000,
        guidance=embedded_guidance,
        encoder_hidden_states=prompt_embeds,
        txt_ids=text_ids,
        img_ids=image_ids,
        kv_cache=kv_cache,
        use_gradient_checkpointing=use_gradient_checkpointing,
        use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
    )
    return model_output[:, :image_seq_len]
edit_latents is not None: + image_seq_len = latents.shape[1] + latents = torch.concat([latents, edit_latents], dim=1) + image_ids = torch.concat([image_ids, edit_image_ids], dim=1) + if embedded_guidance is None: + embedded_guidance = None + elif isinstance(embedded_guidance, torch.Tensor): + embedded_guidance = embedded_guidance.to(device=latents.device, dtype=latents.dtype).flatten() + if embedded_guidance.numel() == 1: + embedded_guidance = embedded_guidance.expand(latents.shape[0]) + elif embedded_guidance.numel() != latents.shape[0]: + raise ValueError("`embedded_guidance` must be a scalar or match the latent batch size.") + else: + embedded_guidance = torch.full((latents.shape[0],), float(embedded_guidance), device=latents.device, dtype=latents.dtype) + if extra_text_embedding is not None: + extra_text_ids = torch.zeros((1, extra_text_embedding.shape[1], 4), dtype=text_ids.dtype, device=text_ids.device) + extra_text_ids[:, :, -1] = torch.arange(prompt_embeds.shape[1], prompt_embeds.shape[1] + extra_text_embedding.shape[1]) + prompt_embeds = torch.concat([prompt_embeds, extra_text_embedding], dim=1) + text_ids = torch.concat([text_ids, extra_text_ids], dim=1) + + height, width = kwargs.get("height"), kwargs.get("width") + if height is not None and width is not None: + feature_height, feature_width = int(height) // 16, int(width) // 16 + else: + feature_height = int(math.sqrt(image_seq_len)) + feature_width = image_seq_len // feature_height if feature_height > 0 else 0 + if feature_height * feature_width != image_seq_len: + raise ValueError("Flux2 feature extraction requires height/width or square latent tokens.") + + features = [] + + def append_feature(feat): + feat = feat[:, :image_seq_len] + batch_size, _, channels = feat.shape + feat = feat.permute(0, 2, 1).reshape(batch_size, channels, feature_height, feature_width) + features.append(feat) + if len(features) == len(feature_indices): + return features + return None + + num_txt_tokens = 
prompt_embeds.shape[1] + timestep = timestep.to(latents.dtype) + guidance = None if embedded_guidance is None else embedded_guidance.to(latents.dtype) * 1000 + temb = dit.time_guidance_embed(timestep, guidance) + + double_stream_mod_img = dit.double_stream_modulation_img(temb) + double_stream_mod_txt = dit.double_stream_modulation_txt(temb) + single_stream_mod = dit.single_stream_modulation(temb)[0] + + hidden_states = dit.x_embedder(latents) + encoder_hidden_states = dit.context_embedder(prompt_embeds) + + if image_ids.ndim == 3: + image_ids = image_ids[0] + if text_ids.ndim == 3: + text_ids = text_ids[0] + + image_rotary_emb = dit.pos_embed(image_ids) + text_rotary_emb = dit.pos_embed(text_ids) + concat_rotary_emb = ( + torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0), + torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0), + ) + + for block_id, block in enumerate(dit.transformer_blocks): + encoder_hidden_states, hidden_states = gradient_checkpoint_forward( + block, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb_mod_params_img=double_stream_mod_img, + temb_mod_params_txt=double_stream_mod_txt, + image_rotary_emb=concat_rotary_emb, + joint_attention_kwargs=None, + kv_cache=None if kv_cache is None else kv_cache.get(f"double_{block_id}"), + ) + if block_id in feature_indices: + selected_features = append_feature(hidden_states) + if selected_features is not None: + return selected_features + + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + num_double_blocks = len(dit.transformer_blocks) + + for block_id, block in enumerate(dit.single_transformer_blocks): + hidden_states = gradient_checkpoint_forward( + block, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + 
hidden_states=hidden_states, + encoder_hidden_states=None, + temb_mod_params=single_stream_mod, + image_rotary_emb=concat_rotary_emb, + joint_attention_kwargs=None, + kv_cache=None if kv_cache is None else kv_cache.get(f"single_{block_id}"), + ) + feature_id = block_id + num_double_blocks + if feature_id in feature_indices: + selected_features = append_feature(hidden_states[:, num_txt_tokens:num_txt_tokens + image_seq_len]) + if selected_features is not None: + return selected_features + + if len(features) != len(feature_indices): + raise ValueError(f"Only collected {len(features)} feature maps for {len(feature_indices)} requested feature indices.") + return features + +def model_fn_flux2_dmd2( + pipe, + dit, + timestep, + progress_id, + num_inference_steps, + inputs_shared, + inputs_posi, + feature_indices=None, + return_features=False, +): + if not return_features: + return _model_fn_flux2_base( + dit=dit, + **inputs_shared, + **inputs_posi, + timestep=timestep, + progress_id=progress_id, + num_inference_steps=num_inference_steps, + ) + + return _model_fn_flux2_features( + dit=dit, + **inputs_shared, + **inputs_posi, + timestep=timestep, + progress_id=progress_id, + num_inference_steps=num_inference_steps, + feature_indices=feature_indices, + ) + + + +class Flux2DMD2TrainingModule(DiffusionTrainingModule): + def __init__( + self, + model_paths=None, model_id_with_origin_paths=None, + tokenizer_path=None, + trainable_models=None, + lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None, + preset_lora_path=None, preset_lora_model=None, + use_gradient_checkpointing=True, + use_gradient_checkpointing_offload=False, + extra_inputs=None, + fp8_models=None, + offload_models=None, + embedded_guidance=1.0, + resume_from_checkpoint=None, + remove_prefix_in_ckpt=None, + dmd2_student_update_freq=5, + dmd2_student_sample_steps=4, + dmd2_student_sample_type="sde", + dmd2_student_schedule="uniform", + dmd2_student_t_list=None, + 
class Flux2DMD2TrainingModule(DiffusionTrainingModule):
    """Training module that prepares a Flux.2 pipeline for the ``"dmd2"`` task.

    On construction it loads the pipeline, switches it to training mode, and
    attaches to ``self.pipe`` two deep copies of ``pipe.dit`` (``dit_teacher``,
    frozen, and ``dit_fake_score``, trainable) plus, when the GAN loss weight
    is positive, a ``FluxDMD2Discriminator``. The attribute names of these
    models are published through the ``dmd2_*_model_name`` fields.
    """

    def __init__(
        self,
        model_paths=None, model_id_with_origin_paths=None,
        tokenizer_path=None,
        trainable_models=None,
        lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
        preset_lora_path=None, preset_lora_model=None,
        use_gradient_checkpointing=True,
        use_gradient_checkpointing_offload=False,
        extra_inputs=None,
        fp8_models=None,
        offload_models=None,
        embedded_guidance=1.0,
        resume_from_checkpoint=None,
        remove_prefix_in_ckpt=None,
        dmd2_student_update_freq=5,
        dmd2_student_sample_steps=4,
        dmd2_student_sample_type="sde",
        dmd2_student_schedule="uniform",
        dmd2_student_t_list=None,
        dmd2_matching_t_min=0.001,
        dmd2_matching_t_max=0.999,
        dmd2_matching_t_sampling="uniform",
        dmd2_matching_t_mean=0.0,
        dmd2_matching_t_std=1.0,
        dmd2_gan_loss_weight=0.03,
        dmd2_gan_r1_reg_weight=0.0,
        dmd2_gan_r1_reg_alpha=0.1,
        dmd2_fake_score_learning_rate=None,
        dmd2_discriminator_learning_rate=None,
        dmd2_feature_indices=None,
        dmd2_teacher_cfg_scale=None,
        dmd2_student_grad_clip_norm=10.0,
        device="cpu",
        task="dmd2",
    ):
        """Build the pipeline, the auxiliary DMD2 models and the loss.

        The ``dmd2_*`` arguments map one-to-one onto ``DMD2Config`` fields.
        ``dmd2_teacher_cfg_scale=None`` means "use the ``DMD2Config`` default"
        rather than storing ``None`` in a field declared as ``float``.
        ``extra_inputs`` is a comma-separated string or ``None``.
        """
        super().__init__()
        feature_indices = _parse_int_list(dmd2_feature_indices)
        config_kwargs = dict(
            student_update_freq=dmd2_student_update_freq,
            student_sample_steps=dmd2_student_sample_steps,
            student_sample_type=dmd2_student_sample_type,
            student_schedule=dmd2_student_schedule,
            student_t_list=_parse_float_list(dmd2_student_t_list),
            matching_t_min=dmd2_matching_t_min,
            matching_t_max=dmd2_matching_t_max,
            matching_t_sampling=dmd2_matching_t_sampling,
            matching_t_mean=dmd2_matching_t_mean,
            matching_t_std=dmd2_matching_t_std,
            gan_loss_weight=dmd2_gan_loss_weight,
            gan_r1_reg_weight=dmd2_gan_r1_reg_weight,
            gan_r1_reg_alpha=dmd2_gan_r1_reg_alpha,
            fake_score_learning_rate=dmd2_fake_score_learning_rate,
            discriminator_learning_rate=dmd2_discriminator_learning_rate,
            feature_indices=feature_indices,
            student_grad_clip_norm=dmd2_student_grad_clip_norm,
        )
        # Only override the teacher CFG scale when one was actually given, so
        # the dataclass default applies instead of a None that would later be
        # handed to the pipeline as `cfg_scale`.
        if dmd2_teacher_cfg_scale is not None:
            config_kwargs["teacher_cfg_scale"] = dmd2_teacher_cfg_scale
        config = DMD2Config(**config_kwargs)

        # Load models
        model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
        tokenizer_config = self.parse_path_or_model_id(tokenizer_path, default_value=ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"))
        self.pipe = Flux2ImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config)
        self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model, remove_unnecessary_params=True)

        # Training mode
        self.switch_pipe_to_training_mode(
            self.pipe, trainable_models,
            lora_base_model, lora_target_modules, lora_rank, lora_checkpoint,
            preset_lora_path, preset_lora_model,
            task=task,
        )

        # Teacher and fake-score models start as copies of the student as it
        # is after the training-mode switch above.
        self.pipe.dit_teacher = copy.deepcopy(self.pipe.dit)
        self.pipe.dit_fake_score = copy.deepcopy(self.pipe.dit)
        self.pipe.dmd2_discriminator = None
        if config.gan_loss_weight > 0:
            self.pipe.dmd2_discriminator = FluxDMD2Discriminator(feature_indices=config.feature_indices)
        self.pipe.dit_teacher.eval().requires_grad_(False)
        self.pipe.dit_fake_score.train().requires_grad_(True)
        if self.pipe.dmd2_discriminator is not None:
            self.pipe.dmd2_discriminator.train().requires_grad_(True)
        self.resume_from_checkpoint(resume_from_checkpoint, remove_prefix_in_ckpt)

        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
        self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
        self.fp8_models = fp8_models
        self.embedded_guidance = embedded_guidance
        self.task = task
        # Attribute names on `self.pipe` under which each DMD2 model lives.
        self.dmd2_student_model_name = "dit"
        self.dmd2_teacher_model_name = "dit_teacher"
        self.dmd2_fake_score_model_name = "dit_fake_score"
        self.dmd2_discriminator_model_name = "dmd2_discriminator"
        self.dmd2_model_fn_student = model_fn_flux2_dmd2
        self.dmd2_model_fn_teacher = model_fn_flux2_dmd2
        self.dmd2_model_fn_fake_score = model_fn_flux2_dmd2
        self.dmd2_loss = DMD2Loss(config)
        # Student: only parameters that are trainable right now.
        # Fake score: every parameter, since all were set trainable above.
        self._dmd2_student_param_names = {name for name, param in self.pipe.dit.named_parameters() if param.requires_grad}
        self._dmd2_fake_score_param_names = {name for name, _ in self.pipe.dit_fake_score.named_parameters()}
        self.task_to_loss = {
            "dmd2": lambda pipe, inputs_shared, inputs_posi, inputs_nega, iteration=None: self.dmd2_loss(
                self,
                (inputs_shared, inputs_posi, inputs_nega),
                0 if iteration is None else iteration,
            ),
        }

    def get_pipeline_inputs(self, data):
        """Build ``(inputs_shared, inputs_posi, inputs_nega)`` from one sample.

        ``data`` must provide ``"prompt"`` and ``"image"``; the image's
        ``size`` is read as ``(width, height)``. The negative prompt is always
        the empty string. Entries named in ``self.extra_inputs`` are merged
        into the shared inputs by ``parse_extra_inputs``.
        """
        inputs_posi = {"prompt": data["prompt"]}
        inputs_nega = {"negative_prompt": ""}
        inputs_shared = {
            "input_image": data["image"],
            "height": data["image"].size[1],
            "width": data["image"].size[0],
            "embedded_guidance": self.embedded_guidance,
            "cfg_scale": self.dmd2_loss.config.teacher_cfg_scale,
            "rand_device": self.pipe.device,
            "use_gradient_checkpointing": self.use_gradient_checkpointing,
            "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
        }
        inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)
        return inputs_shared, inputs_posi, inputs_nega

    def forward(self, data, inputs=None, iteration=None):
        """Run every pipeline unit on the inputs, then evaluate the task loss.

        ``inputs`` defaults to ``get_pipeline_inputs(data)``. ``iteration`` is
        forwarded to the loss callable registered for ``self.task``.
        """
        if inputs is None:
            inputs = self.get_pipeline_inputs(data)
        inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
        for unit in self.pipe.units:
            inputs = self.pipe.unit_runner(unit, self.pipe, *inputs)
        return self.task_to_loss[self.task](self.pipe, *inputs, iteration=iteration)

    def export_trainable_state_dict(self, state_dict, remove_prefix=None):
        """Delegate state-dict filtering to ``export_dmd2_trainable_state_dict``."""
        return export_dmd2_trainable_state_dict(self, state_dict, remove_prefix=remove_prefix)


def flux2_dmd2_parser():
    """Create the command-line parser for the Flux.2 DMD2 training script.

    Starts from the shared general, image-size and DMD2 option groups, sets
    the default ``task`` to ``"dmd2"`` and adds the Flux.2 specific options.
    """
    parser = argparse.ArgumentParser(description="Flux.2 DMD2 training script.")
    parser = add_general_config(parser)
    parser = add_image_size_config(parser)
    parser = add_dmd2_config(parser)
    parser.set_defaults(task="dmd2")
    parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.")
    parser.add_argument("--embedded_guidance", type=float, default=1.0, help="Flux.2 embedded guidance value.")
    parser.add_argument("--initialize_model_on_cpu", default=False, action="store_true", help="Whether to initialize models on CPU.")
    return parser
if __name__ == "__main__":
    # Script entry point: parse options, build the accelerator, dataset,
    # training module and logger, then hand everything to the DMD2 launcher.
    args = flux2_dmd2_parser().parse_args()
    # Forced on regardless of the command line.
    args.find_unused_parameters = True

    ddp_kwargs = accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)
    accelerator = accelerate.Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        kwargs_handlers=[ddp_kwargs],
    )

    # Image loader; both sides are constrained to multiples of 16.
    image_operator = UnifiedDataset.default_image_operator(
        base_path=args.dataset_base_path,
        max_pixels=args.max_pixels,
        height=args.height,
        width=args.width,
        height_division_factor=16,
        width_division_factor=16,
    )
    dataset = UnifiedDataset(
        base_path=args.dataset_base_path,
        metadata_path=args.dataset_metadata_path,
        repeat=args.dataset_repeat,
        data_file_keys=args.data_file_keys.split(","),
        main_data_operator=image_operator,
    )

    # Every one of these options is passed through under its own name.
    forwarded_options = (
        "model_paths",
        "model_id_with_origin_paths",
        "tokenizer_path",
        "trainable_models",
        "lora_base_model",
        "lora_target_modules",
        "lora_rank",
        "lora_checkpoint",
        "preset_lora_path",
        "preset_lora_model",
        "use_gradient_checkpointing",
        "use_gradient_checkpointing_offload",
        "extra_inputs",
        "fp8_models",
        "offload_models",
        "embedded_guidance",
        "resume_from_checkpoint",
        "remove_prefix_in_ckpt",
        "dmd2_student_update_freq",
        "dmd2_student_sample_steps",
        "dmd2_student_sample_type",
        "dmd2_student_schedule",
        "dmd2_student_t_list",
        "dmd2_matching_t_min",
        "dmd2_matching_t_max",
        "dmd2_matching_t_sampling",
        "dmd2_matching_t_mean",
        "dmd2_matching_t_std",
        "dmd2_gan_loss_weight",
        "dmd2_gan_r1_reg_weight",
        "dmd2_gan_r1_reg_alpha",
        "dmd2_fake_score_learning_rate",
        "dmd2_discriminator_learning_rate",
        "dmd2_feature_indices",
        "dmd2_teacher_cfg_scale",
        "dmd2_student_grad_clip_norm",
        "task",
    )
    module_kwargs = {option: getattr(args, option) for option in forwarded_options}
    # Models are created on the CPU when requested explicitly or when CPU
    # offload is enabled; otherwise directly on the accelerator device.
    if args.initialize_model_on_cpu or args.enable_model_cpu_offload:
        module_kwargs["device"] = "cpu"
    else:
        module_kwargs["device"] = accelerator.device
    model = Flux2DMD2TrainingModule(**module_kwargs)

    model_logger = ModelLogger(
        args.output_path,
        remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
        enable_tensorboard_log=args.enable_tensorboard_log,
        enable_swanlab_log=args.enable_swanlab_log,
        swanlab_project=args.swanlab_project,
        enable_wandb_log=args.enable_wandb_log,
        wandb_project=args.wandb_project,
    )
    launch_dmd2_training_task(accelerator, dataset, model, model_logger, args=args)
# Validation script: load the FLUX.2 klein components, overwrite the DiT
# weights with a DMD2 training checkpoint, and render one 4-step sample.
component_configs = [
    ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"),
    ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors"),
    ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
]
tokenizer_config = ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/")
pipe = Flux2ImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=component_configs,
    tokenizer_config=tokenizer_config,
)

# Replace the transformer weights with the trained checkpoint.
state_dict = load_state_dict(
    "./models/train/FLUX.2-klein-base-4B_dmd2/step-3000.safetensors",
    torch_dtype=torch.bfloat16,
)
pipe.dit.load_state_dict(state_dict)

# Fixed seed, four sampling steps, 512x512 output.
prompt = "a dog"
image = pipe(prompt=prompt, seed=0, num_inference_steps=4, height=512, width=512)
image.save("FLUX.2-klein-base-4B_dmd2.jpg")