109 changes: 89 additions & 20 deletions invokeai/app/invocations/qwen_image_denoise.py
@@ -1,5 +1,6 @@
import math
from contextlib import ExitStack
from typing import Callable, Iterator, Optional, Tuple
from typing import Callable, ClassVar, Iterator, Optional, Tuple

import torch
import torchvision.transforms as tv_transforms
@@ -176,6 +177,72 @@ def _unpack_latents(latents: torch.Tensor, height: int, width: int) -> torch.Tensor:
latents = latents.reshape(batch_size, channels // 4, h, w)
return latents

@staticmethod
def _align_ref_latent_dims(rh: int, rw: int) -> tuple[int, int]:
"""Trim reference latent spatial dims to even values for 2x2 packing.

Raises ValueError if the aligned dims would be < 2 (i.e., the reference
latent is too small to produce any valid tokens).
"""
rh_aligned = rh - (rh % 2)
rw_aligned = rw - (rw % 2)
if rh_aligned < 2 or rw_aligned < 2:
raise ValueError(
f"Reference latent spatial dims must be >= 2 after even alignment; "
f"got ({rh_aligned}, {rw_aligned}) from input shape ({rh}, {rw}). "
"Ensure the reference image is at least 16 pixels in each dimension."
)
return rh_aligned, rw_aligned

@staticmethod
def _build_img_shapes(
latent_height: int,
latent_width: int,
ref_latent_height: int | None = None,
ref_latent_width: int | None = None,
) -> list[list[tuple[int, int, int]]]:
"""Build the img_shapes argument for the transformer.

The reference segment (if present) must use its own dims so QwenEmbedRope's
spatial frequencies position ref tokens distinctly from noisy tokens —
otherwise reference content bleeds into the generation as a ghost.
"""
shapes: list[tuple[int, int, int]] = [(1, latent_height // 2, latent_width // 2)]
if ref_latent_height is not None and ref_latent_width is not None:
shapes.append((1, ref_latent_height // 2, ref_latent_width // 2))
return [shapes]

# diffusers' QwenImageEdit(Plus)Pipeline VAE_IMAGE_SIZE = 1024 * 1024 pixels;
# ref images are resized to this area (preserving aspect, snapped to multiples
# of 32) before VAE encoding. We mirror this clamp in latent space so direct
# backend callers — whose i2l may not pass explicit width/height — don't feed
# the transformer an out-of-distribution reference sequence length (which
# also causes a VRAM spike for large inputs).
_REF_TARGET_PIXEL_AREA: ClassVar[int] = 1024 * 1024
_VAE_SCALE_FACTOR: ClassVar[int] = 8

@classmethod
def _maybe_clamp_ref_latent_size(cls, ref_latents: torch.Tensor) -> torch.Tensor:
"""Bilinear-downscale the reference latent if it exceeds diffusers'
VAE_IMAGE_SIZE budget.

Returns the latent unchanged if it's already within budget.
"""
_, _, rh, rw = ref_latents.shape
target_cells = cls._REF_TARGET_PIXEL_AREA // (cls._VAE_SCALE_FACTOR**2)
if rh * rw <= target_cells:
return ref_latents
aspect = rw / rh
target_w_px = math.sqrt(cls._REF_TARGET_PIXEL_AREA * aspect)
target_h_px = target_w_px / aspect
target_w_px = max(32, round(target_w_px / 32) * 32)
target_h_px = max(32, round(target_h_px / 32) * 32)
target_rh = target_h_px // cls._VAE_SCALE_FACTOR
target_rw = target_w_px // cls._VAE_SCALE_FACTOR
return torch.nn.functional.interpolate(
ref_latents, size=(target_rh, target_rw), mode="bilinear", antialias=False
)

def _run_diffusion(self, context: InvocationContext):
inference_dtype = torch.bfloat16
device = TorchDevice.choose_torch_device()
@@ -332,35 +399,37 @@ def _run_diffusion(self, context: InvocationContext):
use_ref_latents = has_zero_cond_t

ref_latents_packed = None
ref_latent_height = latent_height
ref_latent_width = latent_width
if use_ref_latents:
if ref_latents is not None:
_, ref_ch, rh, rw = ref_latents.shape
if rh != latent_height or rw != latent_width:
ref_latents = torch.nn.functional.interpolate(
ref_latents, size=(latent_height, latent_width), mode="bilinear"
)
# Defense-in-depth: backend callers (direct API, older graph JSON)
# may wire qwen_image_i2l without explicit width/height, producing
# a native-resolution reference latent. Clamp here so the
# transformer always sees an in-distribution sequence length.
ref_latents = self._maybe_clamp_ref_latent_size(ref_latents)
_, _, rh, rw = ref_latents.shape
ref_latent_height, ref_latent_width = self._align_ref_latent_dims(rh, rw)
if ref_latent_height != rh or ref_latent_width != rw:
ref_latents = ref_latents[..., :ref_latent_height, :ref_latent_width]
else:
# No reference image provided — use zeros so the model still gets the
# expected sequence layout.
ref_latents = torch.zeros(
1, out_channels, latent_height, latent_width, device=device, dtype=inference_dtype
)
ref_latents_packed = self._pack_latents(ref_latents, 1, out_channels, latent_height, latent_width)

# img_shapes tells the transformer the spatial layout of patches.
ref_latents_packed = self._pack_latents(ref_latents, 1, out_channels, ref_latent_height, ref_latent_width)

# img_shapes tells the transformer the spatial layout of patches. The reference
# segment must use the reference latent's own dimensions so RoPE positions it
# distinctly from the noisy latent — otherwise the two segments share spatial
# positional encoding and the model can't disentangle them, producing a
# ghost/doubling artifact across the whole frame. Matches diffusers'
# QwenImageEditPipeline / QwenImageEditPlusPipeline.
if use_ref_latents:
img_shapes = [
[
(1, latent_height // 2, latent_width // 2),
(1, latent_height // 2, latent_width // 2),
]
]
img_shapes = self._build_img_shapes(latent_height, latent_width, ref_latent_height, ref_latent_width)
else:
img_shapes = [
[
(1, latent_height // 2, latent_width // 2),
]
]
img_shapes = self._build_img_shapes(latent_height, latent_width)

# Prepare inpaint extension (operates in 4D space, so unpack/repack around it)
inpaint_mask = self._prep_inpaint_mask(context, noise) # noise has the right 4D shape
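To make the img_shapes change concrete, here is a minimal worked example, assuming a 1024x1024 output and a reference image resized to 1184x896 (illustrative sizes chosen to match the frontend tests below; the snippet is a sketch, not part of the diff):

# With a VAE scale factor of 8, a 1024x1024 output gives a 128x128 main latent and a
# 1184x896 reference gives a 148x112 reference latent.
latent_height, latent_width = 1024 // 8, 1024 // 8          # 128, 128
ref_latent_height, ref_latent_width = 896 // 8, 1184 // 8   # 112, 148

# 2x2 patch packing halves each spatial dim, so the two sequence segments now carry
# distinct shapes for RoPE:
img_shapes = [
    [
        (1, latent_height // 2, latent_width // 2),          # noisy segment: (1, 64, 64)
        (1, ref_latent_height // 2, ref_latent_width // 2),  # reference segment: (1, 56, 74)
    ]
]
# The removed code used (1, latent_height // 2, latent_width // 2) for both entries, so
# reference and noisy tokens shared spatial positions, which is the ghosting failure mode
# described in the comments above.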
4 changes: 3 additions & 1 deletion invokeai/app/invocations/qwen_image_image_to_latents.py
@@ -83,7 +83,9 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
if self.width is not None and self.height is not None:
image = image.convert("RGB").resize((self.width, self.height), resample=PILImage.LANCZOS)

image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
# multiple_of=16 ensures the post-VAE latents (vae_scale_factor=8) have even
# spatial dims, which the transformer's 2x2 patch packing requires.
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"), multiple_of=16)
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")

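A quick arithmetic check of the multiple_of=16 comment above, as a sketch (the 1000x750 input size and the round-down behavior are assumptions for illustration):

# Snapping each image dimension to a multiple of 16 guarantees the post-VAE latent
# (8x downscale) has even spatial dims, so the denoise node's 2x2 packing needs no trimming.
w, h = 1000, 750                             # hypothetical source image size
w16, h16 = (w // 16) * 16, (h // 16) * 16    # 992, 736
lat_w, lat_h = w16 // 8, h16 // 8            # 124, 92; both even by construction
assert lat_w % 2 == 0 and lat_h % 2 == 0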
buildQwenImageGraph.test.ts
@@ -140,7 +140,12 @@ vi.mock('services/api/types', async () => {
};
});

import { buildQwenImageGraph, isQwenImageEditModel, shouldUseCfg } from './buildQwenImageGraph';
import {
buildQwenImageGraph,
calculateQwenImageEditRefDimensions,
isQwenImageEditModel,
shouldUseCfg,
} from './buildQwenImageGraph';

describe('isQwenImageEditModel', () => {
afterEach(() => {
@@ -415,3 +420,80 @@ describe('buildQwenImageGraph', () => {
expect(hasReferenceLatentsEdge).toBe(false);
});
});

describe('calculateQwenImageEditRefDimensions', () => {
// Cross-checked against diffusers' calculate_dimensions(1024*1024, ratio)
// (see pipeline_qwenimage_edit.py / pipeline_qwenimage_edit_plus.py).
it('produces ~1024² area for a square input', () => {
const result = calculateQwenImageEditRefDimensions(512, 512);
expect(result).toEqual({ width: 1024, height: 1024 });
});

it('preserves aspect ratio for landscape inputs', () => {
expect(calculateQwenImageEditRefDimensions(1600, 1200)).toEqual({ width: 1184, height: 896 });
expect(calculateQwenImageEditRefDimensions(1920, 1080)).toEqual({ width: 1376, height: 768 });
});

it('preserves aspect ratio for portrait inputs', () => {
expect(calculateQwenImageEditRefDimensions(1200, 1600)).toEqual({ width: 896, height: 1184 });
expect(calculateQwenImageEditRefDimensions(1080, 1920)).toEqual({ width: 768, height: 1376 });
});

it('snaps dimensions to multiples of 32', () => {
const { width, height } = calculateQwenImageEditRefDimensions(1600, 1200);
expect(width % 32).toBe(0);
expect(height % 32).toBe(0);
});

it('clamps to a minimum of 32 for extreme aspect ratios', () => {
// 50000x100 has an extreme 500:1 aspect ratio; the computed height (~46px) snaps to 32, the minimum the clamp allows.
const { width, height } = calculateQwenImageEditRefDimensions(50000, 100);
expect(height).toBeGreaterThanOrEqual(32);
expect(width).toBeGreaterThanOrEqual(32);
expect(width % 32).toBe(0);
expect(height % 32).toBe(0);
});

it('passes computed dims as width/height to the reference i2l node', async () => {
const { selectMainModelConfig } = await import('features/controlLayers/store/paramsSlice');
const editModel = { ...model, variant: 'edit' };
vi.mocked(selectMainModelConfig).mockReturnValue(editModel as never);

const { fetchModelConfigWithTypeGuard } = await import('features/metadata/util/modelFetchingHelpers');
vi.mocked(fetchModelConfigWithTypeGuard).mockResolvedValue(editModel as never);

const { selectRefImagesSlice } = await import('features/controlLayers/store/refImagesSlice');
vi.mocked(selectRefImagesSlice).mockReturnValue({
entities: [
{
id: 'ref-image-1',
isEnabled: true,
config: {
type: 'qwen_image_reference_image',
image: { original: { image: { image_name: 'ref.png', width: 1600, height: 1200 } } },
},
},
],
} as never);

const { g } = await buildQwenImageGraph({
generationMode: 'txt2img',
manager: null,
state: {
system: { shouldUseNSFWChecker: false, shouldUseWatermarker: false },
} as never,
});

const graph = g.getGraph();
const refI2lNodeId = Object.keys(graph.nodes).find((id) => id.startsWith('qwen_ref_i2l:'));
expect(refI2lNodeId).toBeDefined();
const refI2lNode = graph.nodes[refI2lNodeId!] as { width?: number; height?: number };
expect(refI2lNode.width).toBe(1184);
expect(refI2lNode.height).toBe(896);

// Restore mocks
vi.mocked(selectMainModelConfig).mockReturnValue(model as never);
vi.mocked(fetchModelConfigWithTypeGuard).mockResolvedValue(model as never);
vi.mocked(selectRefImagesSlice).mockReturnValue(refImagesSlice as never);
});
});
buildQwenImageGraph.ts
@@ -51,6 +51,27 @@ export const shouldUseCfg = (cfgScale: number | number[]): boolean => {
return cfgScale.some((value) => value > 1);
};

/**
* Compute the target dimensions for the VAE-encoded reference image, matching
* diffusers' `calculate_dimensions(VAE_IMAGE_SIZE=1024*1024, aspect_ratio)` used
* by QwenImageEditPipeline / QwenImageEditPlusPipeline. The reference is resized
* so its area is ~1024² while preserving aspect ratio, with each dimension
* snapped to a multiple of 32 (the model was trained at this scale; feeding it a
* much larger reference produces a sequence length it was not trained on).
*/
const QWEN_IMAGE_EDIT_REF_TARGET_AREA = 1024 * 1024;
export const calculateQwenImageEditRefDimensions = (
width: number,
height: number
): { width: number; height: number } => {
const ratio = width / height;
let w = Math.sqrt(QWEN_IMAGE_EDIT_REF_TARGET_AREA * ratio);
let h = w / ratio;
w = Math.max(32, Math.round(w / 32) * 32);
h = Math.max(32, Math.round(h / 32) * 32);
return { width: w, height: h };
};

export const buildQwenImageGraph = async (arg: GraphBuilderArg): Promise<GraphBuilderReturn> => {
const { generationMode, state, manager } = arg;

@@ -175,15 +196,18 @@ export const buildQwenImageGraph = async (arg: GraphBuilderArg): Promise<GraphBuilderReturn> => {
// Also VAE-encode the first reference image as latents for the denoising transformer.
// The transformer expects [noisy_patches ; ref_patches] in its sequence.
const firstConfig = validRefImageConfigs[0]!;
const firstImgField = zImageField.parse(
firstConfig.config.image?.crop?.image ?? firstConfig.config.image?.original.image
);
// Don't force-resize the reference image to the output dimensions — that would
// distort the aspect ratio when they differ. The I2L encodes at the image's
// native size; the denoise node handles dimension mismatches via interpolation.
const firstImage = firstConfig.config.image?.crop?.image ?? firstConfig.config.image?.original.image;
const firstImgField = zImageField.parse(firstImage);
// Resize the reference image to ~1024² area preserving aspect ratio, matching the
// diffusers QwenImageEdit(Plus)Pipeline's VAE_IMAGE_SIZE. The denoise node uses
// the reference latent's own dimensions for RoPE, so the ref segment is encoded
// at the resolution the model was trained on rather than the source image's
// native size.
const refDims = firstImage ? calculateQwenImageEditRefDimensions(firstImage.width, firstImage.height) : undefined;
const refI2l = g.addNode({
type: 'qwen_image_i2l',
id: getPrefixedId('qwen_ref_i2l'),
...(refDims ? { width: refDims.width, height: refDims.height } : {}),
});
const refImageNode = g.addNode({
type: 'image',