Skip to content
Open
52 changes: 52 additions & 0 deletions examples/dreambooth/test_dreambooth_lora_flux2_klein.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,3 +260,55 @@ def test_dreambooth_lora_with_metadata(self):
self.assertTrue(loaded_lora_alpha == lora_alpha)
loaded_lora_rank = raw["transformer.r"]
self.assertTrue(loaded_lora_rank == rank)

def test_dreambooth_lora_flux2_aspect_ratio_buckets(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--aspect_ratio_buckets 64,64;64,128
--bucket_no_upscale
--cache_latents
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--lr_scheduler constant
--lr_warmup_steps 0
--max_sequence_length 8
--text_encoder_out_layers 1
--output_dir {tmpdir}
""".split()

run_command(self._launch_args + test_args)
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
self.assertTrue(all("lora" in k for k in lora_state_dict.keys()))
self.assertTrue(all(key.startswith("transformer") for key in lora_state_dict.keys()))

def test_dreambooth_lora_flux2_caption_dropout(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--caption_dropout 1.0
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--lr_scheduler constant
--lr_warmup_steps 0
--max_sequence_length 8
--text_encoder_out_layers 1
--output_dir {tmpdir}
""".split()

run_command(self._launch_args + test_args)
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
self.assertTrue(all("lora" in k for k in lora_state_dict.keys()))
48 changes: 48 additions & 0 deletions examples/dreambooth/test_dreambooth_lora_qwenimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,3 +246,51 @@ def test_dreambooth_lora_with_metadata(self):
self.assertTrue(loaded_lora_alpha == lora_alpha)
loaded_lora_rank = raw["transformer.r"]
self.assertTrue(loaded_lora_rank == rank)

def test_dreambooth_lora_qwen_aspect_ratio_buckets(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--aspect_ratio_buckets 64,64;64,128
--bucket_no_upscale
--cache_latents
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()

run_command(self._launch_args + test_args)
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
self.assertTrue(all("lora" in k for k in lora_state_dict.keys()))
self.assertTrue(all(key.startswith("transformer") for key in lora_state_dict.keys()))

def test_dreambooth_lora_qwen_caption_dropout(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--caption_dropout 1.0
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()

run_command(self._launch_args + test_args)
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
self.assertTrue(all("lora" in k for k in lora_state_dict.keys()))
58 changes: 26 additions & 32 deletions examples/dreambooth/train_dreambooth_lora_flux2.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,11 @@ def parse_args(input_args=None):
"Images will be resized and cropped to fit the nearest bucket. If provided, --resolution is ignored."
),
)
parser.add_argument(
"--bucket_no_upscale",
action="store_true",
help="If set, images smaller than their aspect-ratio bucket are padded instead of upscaled.",
)
parser.add_argument(
"--center_crop",
default=False,
Expand Down Expand Up @@ -890,15 +895,6 @@ def __init__(
else:
self.class_data_root = None

self.image_transforms = transforms.Compose(
[
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)

def __len__(self):
return self._length

Expand All @@ -924,37 +920,35 @@ def __getitem__(self, index):

if not class_image.mode == "RGB":
class_image = class_image.convert("RGB")
example["class_images"] = self.image_transforms(class_image)
# Match the class image to the paired instance image's bucket so they can be stacked into one batch.
example["class_images"] = self.train_transform(
class_image, size=self.buckets[bucket_idx], center_crop=self.center_crop
)
example["class_prompt"] = self.class_prompt

return example

def train_transform(self, image, size=(224, 224), center_crop=False, random_flip=False):
# 1. Resize (deterministic)
resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
image = resize(image)

# 2. Crop: either center or SAME random crop
def train_transform(self, image, size, center_crop=False, random_flip=False):
# Resize preserving aspect ratio so the image covers the bucket, then crop to the bucket size.
target_height, target_width = size
width, height = image.size
scale = max(target_height / height, target_width / width)
if args.bucket_no_upscale:
scale = min(scale, 1.0)
new_height, new_width = round(height * scale), round(width * scale)
image = TF.resize(image, [new_height, new_width], interpolation=transforms.InterpolationMode.BILINEAR)
# Pad to the bucket when no-upscale leaves the image smaller, so batched samples share a shape.
pad_w, pad_h = max(0, target_width - new_width), max(0, target_height - new_height)
if pad_w or pad_h:
image = TF.pad(image, [pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2])
if center_crop:
crop = transforms.CenterCrop(size)
image = crop(image)
image = TF.center_crop(image, size)
else:
# get_params returns (i, j, h, w)
i, j, h, w = transforms.RandomCrop.get_params(image, output_size=size)
image = TF.crop(image, i, j, h, w)

# 3. Random horizontal flip with the SAME coin flip
if random_flip:
do_flip = random.random() < 0.5
if do_flip:
image = TF.hflip(image)

# 4. ToTensor + Normalize (deterministic)
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize([0.5], [0.5])
image = normalize(to_tensor(image))

return image
if random_flip and random.random() < 0.5:
image = TF.hflip(image)
return TF.normalize(TF.to_tensor(image), [0.5], [0.5])


def collate_fn(examples, with_prior_preservation=False):
Expand Down
64 changes: 30 additions & 34 deletions examples/dreambooth/train_dreambooth_lora_flux2_img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,11 @@ def parse_args(input_args=None):
"Images will be resized and cropped to fit the nearest bucket. If provided, --resolution is ignored."
),
)
parser.add_argument(
"--bucket_no_upscale",
action="store_true",
help="If set, images smaller than their aspect-ratio bucket are padded instead of upscaled.",
)
parser.add_argument(
"--center_crop",
default=False,
Expand Down Expand Up @@ -884,15 +889,6 @@ def __init__(
self.num_instance_images = len(self.instance_images)
self._length = self.num_instance_images

self.image_transforms = transforms.Compose(
[
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)

def __len__(self):
return self._length

Expand All @@ -918,40 +914,40 @@ def __getitem__(self, index):
return example

def paired_transform(self, image, dest_image=None, size=(224, 224), center_crop=False, random_flip=False):
# 1. Resize (deterministic)
resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
image = resize(image)
# Resize preserving aspect ratio so the image covers the bucket, then crop to the bucket size.
# The same geometry is applied to the conditioning image so the pair stays aligned.
target_height, target_width = size
width, height = image.size
scale = max(target_height / height, target_width / width)
if args.bucket_no_upscale:
scale = min(scale, 1.0)
new_size = [round(height * scale), round(width * scale)]
# Pad to the bucket when no-upscale leaves the image smaller, so batched samples share a shape.
pad_w, pad_h = max(0, target_width - new_size[1]), max(0, target_height - new_size[0])
padding = [pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2]
image = TF.resize(image, new_size, interpolation=transforms.InterpolationMode.BILINEAR)
if pad_w or pad_h:
image = TF.pad(image, padding)
if dest_image is not None:
dest_image = resize(dest_image)

# 2. Crop: either center or SAME random crop
dest_image = TF.resize(dest_image, new_size, interpolation=transforms.InterpolationMode.BILINEAR)
if pad_w or pad_h:
dest_image = TF.pad(dest_image, padding)
if center_crop:
crop = transforms.CenterCrop(size)
image = crop(image)
image = TF.center_crop(image, size)
if dest_image is not None:
dest_image = crop(dest_image)
dest_image = TF.center_crop(dest_image, size)
else:
# get_params returns (i, j, h, w)
i, j, h, w = transforms.RandomCrop.get_params(image, output_size=size)
image = TF.crop(image, i, j, h, w)
if dest_image is not None:
dest_image = TF.crop(dest_image, i, j, h, w)

# 3. Random horizontal flip with the SAME coin flip
if random_flip:
do_flip = random.random() < 0.5
if do_flip:
image = TF.hflip(image)
if dest_image is not None:
dest_image = TF.hflip(dest_image)

# 4. ToTensor + Normalize (deterministic)
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize([0.5], [0.5])
image = normalize(to_tensor(image))
if random_flip and random.random() < 0.5:
image = TF.hflip(image)
if dest_image is not None:
dest_image = TF.hflip(dest_image)
image = TF.normalize(TF.to_tensor(image), [0.5], [0.5])
if dest_image is not None:
dest_image = normalize(to_tensor(dest_image))

dest_image = TF.normalize(TF.to_tensor(dest_image), [0.5], [0.5])
return (image, dest_image) if dest_image is not None else (image, None)


Expand Down
Loading
Loading