- """Fine-tuning script for Stable Diffusion for text2image with support for ControlLoRA."""
- """Code refer to https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py"""
-
- # Silence diffusers' deprecation warnings by no-op-ing the `deprecate` helpers.
- from diffusers import utils
- from diffusers.utils import deprecation_utils
- from diffusers.models import cross_attention
- utils.deprecate = lambda *args, **kwargs: None
- deprecation_utils.deprecate = lambda *args, **kwargs: None
- cross_attention.deprecate = lambda *args, **kwargs: None
-
- import argparse
- import logging
- import math
- import os
- import random
- from pathlib import Path
- from typing import Optional
-
- import datasets
- import numpy as np
- import torch
- import torch.nn.functional as F
- import torch.utils.checkpoint
- import transformers
- from accelerate import Accelerator
- from accelerate.logging import get_logger
- from accelerate.utils import set_seed
- from datasets import load_dataset
- from huggingface_hub import HfFolder, Repository, create_repo, whoami
- from torchvision import transforms
- from tqdm.auto import tqdm
- from transformers import CLIPTextModel, CLIPTokenizer
-
- import diffusers
- from diffusers import (
- AutoencoderKL,
- DDPMScheduler,
- DPMSolverMultistepScheduler,
- DiffusionPipeline,
- UNet2DConditionModel)
- from diffusers.optimization import get_scheduler
- from diffusers.utils import check_min_version, is_wandb_available
- from diffusers.utils.import_utils import is_xformers_available
- from models import ControlLoRA
- from process import DatasetBase as dataset_cls
-
-
- # Will error if the minimal version of diffusers is not installed. Remove at your own risk.
- check_min_version("0.13.0.dev0")
-
- logger = get_logger(__name__, log_level="INFO")
-
-
- def save_model_card(repo_name, images=None, base_model: str = "", dataset_name: str = "", repo_folder=None):
- img_str = ""
- for i, image in enumerate(images or []):
- image.save(os.path.join(repo_folder, f"image_{i}.png"))
- img_str += f"![img_{i}](./image_{i}.png)\n"
-
- yaml = f"""
- ---
- license: creativeml-openrail-m
- base_model: {base_model}
- tags:
- - stable-diffusion
- - stable-diffusion-diffusers
- - text-to-image
- - diffusers
- - lora
- - controlnet
- - control-lora
- inference: true
- ---
- """
- model_card = f"""
- # ControlLoRA text2image fine-tuning - {repo_name}
- These are ControlLoRA adaptation weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images below. \n
- {img_str}
- """
- with open(os.path.join(repo_folder, "README.md"), "w") as f:
- f.write(yaml + model_card)
-
-
- def parse_args():
- parser = argparse.ArgumentParser(description="Simple example of a training script.")
- parser.add_argument(
- "--pretrained_model_name_or_path",
- type=str,
- default=None,
- required=True,
- help="Path to pretrained model or model identifier from huggingface.co/models.",
- )
- parser.add_argument(
- "--revision",
- type=str,
- default=None,
- required=False,
- help="Revision of pretrained model identifier from huggingface.co/models.",
- )
- parser.add_argument(
- "--dataset_name",
- type=str,
- default=None,
- help=(
- "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
- " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
- " or to a folder containing files that 🤗 Datasets can understand."
- ),
- )
- parser.add_argument(
- "--dataset_config_name",
- type=str,
- default=None,
- help="The config of the Dataset, leave as None if there's only one config.",
- )
- parser.add_argument(
- "--train_data_dir",
- type=str,
- default=None,
- help=(
- "A folder containing the training data. Folder contents must follow the structure described in"
- " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
- " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
- ),
- )
- parser.add_argument(
- "--image_column", type=str, default="image", help="The column of the dataset containing an image."
- )
- parser.add_argument(
- "--guide_column", type=str, default="guide", help="The column of the dataset containing a guide image."
- )
- parser.add_argument(
- "--caption_column",
- type=str,
- default="text",
- help="The column of the dataset containing a caption or a list of captions.",
- )
- parser.add_argument(
- "--validation_prompt", type=str, default=None, help="A prompt that is sampled during training for inference."
- )
- parser.add_argument(
- "--num_validation_images",
- type=int,
- default=16,
- help="Number of images that should be generated during validation with `validation_prompt`.",
- )
- parser.add_argument(
- "--validation_epochs",
- type=int,
- default=1,
- help=(
- "Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
- " `args.validation_prompt` multiple times: `args.num_validation_images`."
- ),
- )
- parser.add_argument(
- "--max_train_samples",
- type=int,
- default=None,
- help=(
- "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
- ),
- )
- parser.add_argument(
- "--output_dir",
- type=str,
- default="sd-fill50k-model-control-lora",
- help="The output directory where the model predictions and checkpoints will be written.",
- )
- parser.add_argument(
- "--cache_dir",
- type=str,
- default=None,
- help="The directory where the downloaded models and datasets will be stored.",
- )
- parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
- parser.add_argument(
- "--resolution",
- type=int,
- default=512,
- help=(
- "The resolution for input images, all the images in the train/validation dataset will be resized to this"
- " resolution"
- ),
- )
- parser.add_argument(
- "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
- )
- parser.add_argument("--num_train_epochs", type=int, default=100)
- parser.add_argument(
- "--max_train_steps",
- type=int,
- default=None,
- help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
- )
- parser.add_argument(
- "--gradient_accumulation_steps",
- type=int,
- default=1,
- help="Number of updates steps to accumulate before performing a backward/update pass.",
- )
- parser.add_argument(
- "--gradient_checkpointing",
- action="store_true",
- help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
- )
- parser.add_argument(
- "--learning_rate",
- type=float,
- default=1e-4,
- help="Initial learning rate (after the potential warmup period) to use.",
- )
- parser.add_argument(
- "--scale_lr",
- action="store_true",
- default=False,
- help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
- )
- parser.add_argument(
- "--lr_scheduler",
- type=str,
- default="constant",
- help=(
- 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
- ' "constant", "constant_with_warmup"]'
- ),
- )
- parser.add_argument(
- "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
- )
- parser.add_argument(
- "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
- )
- parser.add_argument(
- "--allow_tf32",
- action="store_true",
- help=(
- "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
- " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
- ),
- )
- parser.add_argument(
- "--dataloader_num_workers",
- type=int,
- default=0,
- help=(
- "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
- ),
- )
- parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
- parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
- parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
- parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
- parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
- parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
- parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
- parser.add_argument(
- "--hub_model_id",
- type=str,
- default=None,
- help="The name of the repository to keep in sync with the local `output_dir`.",
- )
- parser.add_argument(
- "--logging_dir",
- type=str,
- default="logs",
- help=(
- "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
- " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
- ),
- )
- parser.add_argument(
- "--mixed_precision",
- type=str,
- default=None,
- choices=["no", "fp16", "bf16"],
- help=(
- "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
- " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
- " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
- ),
- )
- parser.add_argument(
- "--report_to",
- type=str,
- default="tensorboard",
- help=(
- 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
- ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
- ),
- )
- parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
- parser.add_argument(
- "--checkpointing_steps",
- type=int,
- default=500,
- help=(
- "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
- " training using `--resume_from_checkpoint`."
- ),
- )
- parser.add_argument(
- "--resume_from_checkpoint",
- type=str,
- default=None,
- help=(
- "Whether training should be resumed from a previous checkpoint. Use a path saved by"
- ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
- ),
- )
- parser.add_argument(
- "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
- )
- parser.add_argument("--control_lora_config", type=str, required=True, help="Config file of ControlLora")
-
- args = parser.parse_args()
- env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
- if env_local_rank != -1 and env_local_rank != args.local_rank:
- args.local_rank = env_local_rank
-
- # Sanity checks
- if args.dataset_name is None and args.train_data_dir is None:
- raise ValueError("Need either a dataset name or a training folder.")
-
- return args
-
-
- def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
- if token is None:
- token = HfFolder.get_token()
- if organization is None:
- username = whoami(token)["name"]
- return f"{username}/{model_id}"
- else:
- return f"{organization}/{model_id}"
-
-
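- # Column-name triples (image, guide, caption) for known hub datasets.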
- DATASET_NAME_MAPPING = {
- "HighCWu/fill50k": ("image", "guide", "text"),
- }
-
-
- def main():
- args = parse_args()
- logging_dir = os.path.join(args.output_dir, args.logging_dir)
-
- accelerator = Accelerator(
- gradient_accumulation_steps=args.gradient_accumulation_steps,
- mixed_precision=args.mixed_precision,
- log_with=args.report_to,
- logging_dir=logging_dir,
- )
- if args.report_to == "wandb":
- if not is_wandb_available():
- raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
- import wandb
-
- # Make one log on every process with the configuration for debugging.
- logging.basicConfig(
- format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
- datefmt="%m/%d/%Y %H:%M:%S",
- level=logging.INFO,
- )
- logger.info(accelerator.state, main_process_only=False)
- if accelerator.is_local_main_process:
- datasets.utils.logging.set_verbosity_warning()
- transformers.utils.logging.set_verbosity_warning()
- diffusers.utils.logging.set_verbosity_info()
- else:
- datasets.utils.logging.set_verbosity_error()
- transformers.utils.logging.set_verbosity_error()
- diffusers.utils.logging.set_verbosity_error()
-
- # If passed along, set the training seed now.
- if args.seed is not None:
- set_seed(args.seed)
-
- # Handle the repository creation
- if accelerator.is_main_process:
- if args.push_to_hub:
- if args.hub_model_id is None:
- repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
- else:
- repo_name = args.hub_model_id
- repo_name = create_repo(repo_name, exist_ok=True)
- repo = Repository(args.output_dir, clone_from=repo_name)
-
- # Append rather than truncate so pre-existing ignore rules survive.
- gitignore_path = Path(args.output_dir) / ".gitignore"
- existing = gitignore_path.read_text() if gitignore_path.exists() else ""
- with open(gitignore_path, "a") as gitignore:
- if "step_*" not in existing:
- gitignore.write("step_*\n")
- if "epoch_*" not in existing:
- gitignore.write("epoch_*\n")
- elif args.output_dir is not None:
- os.makedirs(args.output_dir, exist_ok=True)
-
- # Load scheduler, tokenizer and models.
- noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
- tokenizer = CLIPTokenizer.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
- )
- text_encoder = CLIPTextModel.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
- )
- vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
- unet = UNet2DConditionModel.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
- )
-
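- # Map every UNet attention processor to a "control id" (one per resolution level):
- # down blocks keep their block index, the mid block takes the last id, and up blocks
- # use the reversed order. The collected `cross_attention_dims` tuple is not used
- # directly below; it serves as a reference when authoring the ControlLoRA config.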
- n_ch = len(unet.config.block_out_channels)
- control_ids = list(range(n_ch))
- cross_attention_dims = {i: [] for i in range(n_ch)}
- for name in unet.attn_processors.keys():
- cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
- if name.startswith("mid_block"):
- control_id = control_ids[-1]
- elif name.startswith("up_blocks"):
- block_id = int(name[len("up_blocks.")])
- control_id = list(reversed(control_ids))[block_id]
- elif name.startswith("down_blocks"):
- block_id = int(name[len("down_blocks.")])
- control_id = control_ids[block_id]
- cross_attention_dims[control_id].append(cross_attention_dim)
- cross_attention_dims = tuple([cross_attention_dims[control_id] for control_id in control_ids])
-
- control_lora = ControlLoRA.from_config(args.control_lora_config)
-
- # freeze parameters of models to save more memory
- unet.requires_grad_(False)
- vae.requires_grad_(False)
-
- text_encoder.requires_grad_(False)
-
- # For mixed precision training we cast the text_encoder and vae weights to half-precision
- # as these models are only used for inference, keeping weights in full precision is not required.
- weight_dtype = torch.float32
- if accelerator.mixed_precision == "fp16":
- weight_dtype = torch.float16
- elif accelerator.mixed_precision == "bf16":
- weight_dtype = torch.bfloat16
-
- # Move unet, vae and text_encoder to device and cast to weight_dtype
- unet.to(accelerator.device, dtype=weight_dtype)
- vae.to(accelerator.device, dtype=weight_dtype)
- text_encoder.to(accelerator.device, dtype=weight_dtype)
- control_lora.to(accelerator.device)  # ControlLoRA stays in full precision; it is prepared by the accelerator below
-
- if args.enable_xformers_memory_efficient_attention:
- if is_xformers_available():
- unet.enable_xformers_memory_efficient_attention()
- else:
- raise ValueError("xformers is not available. Make sure it is installed correctly")
-
- # now we will add new LoRA weights to the attention layers
- # It's important to realize here how many attention weights will be added and of which sizes
- # The sizes of the attention layers consist only of two different variables:
- # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
- # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.
-
- # Let's first see how many attention processors we will have to set.
- # For Stable Diffusion, it should be equal to:
- # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
- # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
- # - up blocks (2x attention layers) * (3x transformer layers) * (3x up blocks) = 18
- # => 32 layers
-
- # Set correct lora layers
- lora_attn_procs = {}
- lora_layers_list = [list(layer_list) for layer_list in control_lora.lora_layers]
- for name in unet.attn_processors.keys():
- cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
- if name.startswith("mid_block"):
- control_id = control_ids[-1]
- elif name.startswith("up_blocks"):
- block_id = int(name[len("up_blocks.")])
- control_id = list(reversed(control_ids))[block_id]
- elif name.startswith("down_blocks"):
- block_id = int(name[len("down_blocks.")])
- control_id = control_ids[block_id]
-
- lora_layers = lora_layers_list[control_id]
- if len(lora_layers) != 0:
- lora_layer = lora_layers.pop(0)
- lora_attn_procs[name] = lora_layer
-
- unet.set_attn_processor(lora_attn_procs)
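- # Note: `set_attn_processor` expects the dict to cover every attention processor,
- # e.g. for the Stable Diffusion 1.x UNet:
- #   assert len(lora_attn_procs) == len(unet.attn_processors) == 32
- # so the ControlLoRA config must provide one LoRA layer per attention module.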
-
- # Enable TF32 for faster training on Ampere GPUs,
- # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
- if args.allow_tf32:
- torch.backends.cuda.matmul.allow_tf32 = True
-
- if args.scale_lr:
- args.learning_rate = (
- args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
- )
-
- # Initialize the optimizer
- if args.use_8bit_adam:
- try:
- import bitsandbytes as bnb
- except ImportError:
- raise ImportError(
- "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
- )
-
- optimizer_cls = bnb.optim.AdamW8bit
- else:
- optimizer_cls = torch.optim.AdamW
-
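- # Only the ControlLoRA parameters are optimized; the UNet, VAE and text encoder are frozen above.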
- optimizer = optimizer_cls(
- control_lora.parameters(),
- lr=args.learning_rate,
- betas=(args.adam_beta1, args.adam_beta2),
- weight_decay=args.adam_weight_decay,
- eps=args.adam_epsilon,
- )
-
- # Preprocessing the datasets.
- # We need to tokenize input captions and transform the images.
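- # Note: `caption_column` is resolved later (once the dataset is loaded) and is captured
- # here by closure; `tokenize_captions` is only called after that name exists.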
- def tokenize_captions(examples, is_train=True):
- captions = []
- for caption in examples[caption_column]:
- if isinstance(caption, str):
- captions.append(caption)
- elif isinstance(caption, (list, np.ndarray)):
- # take a random caption if there are multiple
- captions.append(random.choice(caption) if is_train else caption[0])
- else:
- raise ValueError(
- f"Caption column `{caption_column}` should contain either strings or lists of strings."
- )
- inputs = tokenizer(
- captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
- )
- return inputs.input_ids
-
- # Get the datasets: you can either provide your own training and evaluation files (see below)
- # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
-
- # In distributed training, the load_dataset function guarantees that only one local process can concurrently
- # download the dataset.
- global dataset_cls
- use_custom_dataset = False
- if args.dataset_name is not None and args.dataset_name.startswith('process/'):
- # Use a custom dataset defined in `process`
- use_custom_dataset = True
- dataset_cls = dataset_cls.from_name(args.dataset_name)
- dataset = dataset_cls(tokenize_captions, resolution=args.resolution, use_crop=True)
- elif args.dataset_name is not None:
- # Downloading and loading a dataset from the hub.
- dataset = load_dataset(
- args.dataset_name,
- args.dataset_config_name,
- cache_dir=args.cache_dir,
- )
- else:
- data_files = {}
- if args.train_data_dir is not None:
- data_files["train"] = os.path.join(args.train_data_dir, "**")
- dataset = load_dataset(
- "imagefolder",
- data_files=data_files,
- cache_dir=args.cache_dir,
- )
- # See more about loading custom images at
- # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
-
- if use_custom_dataset:
- collate_fn = None
- train_dataset = dataset
- caption_column = args.caption_column
- else:
- # Preprocessing the datasets.
- # We need to tokenize inputs and targets.
- column_names = dataset["train"].column_names
-
- # 6. Get the column names for input/target.
- dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
- if args.image_column is None:
- image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
- else:
- image_column = args.image_column
- if image_column not in column_names:
- raise ValueError(
- f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
- )
- if args.guide_column is None:
- guide_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
- else:
- guide_column = args.guide_column
- if guide_column not in column_names:
- raise ValueError(
- f"--guide_column' value '{args.guide_column}' needs to be one of: {', '.join(column_names)}"
- )
- if args.caption_column is None:
- caption_column = dataset_columns[2] if dataset_columns is not None else column_names[2]
- else:
- caption_column = args.caption_column
- if caption_column not in column_names:
- raise ValueError(
- f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
- )
-
- # Preprocessing the datasets.
- train_transforms = transforms.Compose(
- [
- transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
- # No CenterCrop here: `preprocess_train` below random-crops the image and its guide
- # with the same offsets so the pair stays spatially aligned.
- transforms.ToTensor(),
- transforms.Normalize([0.5], [0.5]),
- ]
- )
-
- def preprocess_train(examples):
- images, guides = [], []
- for image, guide in zip(examples[image_column], examples[guide_column]):
- image, guide = image.convert("RGB"), guide.convert("RGB")
- image, guide = train_transforms(image), train_transforms(guide)
- c, h, w = image.shape
- y1, x1 = 0, 0
- # After `Resize`, the shorter edge equals the target resolution; crop the longer
- # edge at a random offset shared by image and guide.
- if h > args.resolution:
- y1 = torch.randint(0, h - args.resolution + 1, (1, )).item()
- elif w > args.resolution:
- x1 = torch.randint(0, w - args.resolution + 1, (1, )).item()
- y2, x2 = y1 + args.resolution, x1 + args.resolution
- image = image[:, y1:y2, x1:x2]
- guide = guide[:, y1:y2, x1:x2]
- images.append(image)
- guides.append(guide)
-
- examples["pixel_values"] = images
- examples["guide_values"] = guides
- examples["input_ids"] = tokenize_captions(examples)
- return examples
-
- with accelerator.main_process_first():
- if args.max_train_samples is not None:
- dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
- # Set the training transforms
- train_dataset = dataset["train"].with_transform(preprocess_train)
-
- def collate_fn(examples):
- pixel_values = torch.stack([example["pixel_values"] for example in examples])
- pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
- guide_values = torch.stack([example["guide_values"] for example in examples])
- guide_values = guide_values.to(memory_format=torch.contiguous_format).float()
- input_ids = torch.stack([example["input_ids"] for example in examples])
- return {"pixel_values": pixel_values, "guide_values": guide_values, "input_ids": input_ids}
-
- # DataLoaders creation:
- train_dataloader = torch.utils.data.DataLoader(
- train_dataset,
- shuffle=True,
- collate_fn=collate_fn,
- batch_size=args.train_batch_size,
- num_workers=args.dataloader_num_workers,
- )
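- # The "validation" loader below draws single samples from the training set; there is no separate split.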
- val_dataloader = torch.utils.data.DataLoader(
- train_dataset,
- shuffle=False,
- collate_fn=collate_fn,
- batch_size=1,
- num_workers=0,
- )
- val_iter = iter(val_dataloader)
-
- # Scheduler and math around the number of training steps.
- overrode_max_train_steps = False
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
- if args.max_train_steps is None:
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
- overrode_max_train_steps = True
-
- lr_scheduler = get_scheduler(
- args.lr_scheduler,
- optimizer=optimizer,
- num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
- num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
- )
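- # The warmup/total step counts are scaled by `gradient_accumulation_steps` because, in the
- # upstream training loop this script derives from, the scheduler is stepped once per
- # micro-batch rather than once per optimizer update.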
-
- # Prepare everything with our `accelerator`.
- control_lora, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
- control_lora, optimizer, train_dataloader, lr_scheduler
- )
-
- # We need to recalculate our total training steps as the size of the training dataloader may have changed.
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
- if overrode_max_train_steps:
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
- # Afterwards we recalculate our number of training epochs
- args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
-
- # Train!
- total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
-
- logger.info("***** Running testing *****")
- logger.info(f" Num examples = {len(train_dataset)}")
- logger.info(f" Num Epochs = {args.num_train_epochs}")
- logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
- logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
- logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
- logger.info(f" Total optimization steps = {args.max_train_steps}")
-
- # Potentially load in the weights and states from a previous save
- if args.resume_from_checkpoint:
- if args.resume_from_checkpoint != "latest":
- path = os.path.basename(args.resume_from_checkpoint)
- else:
- # Get the most recent checkpoint
- dirs = os.listdir(args.output_dir)
- dirs = [d for d in dirs if d.startswith("checkpoint")]
- dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
- path = dirs[-1] if len(dirs) > 0 else None
-
- if path is None:
- raise ValueError(
- f"Checkpoint '{args.resume_from_checkpoint}' does not exist."
- )
- else:
- accelerator.print(f"Resuming from checkpoint {path}")
- accelerator.load_state(os.path.join(args.output_dir, path))
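- # `accelerator.load_state` restores everything registered via `prepare` above,
- # i.e. the ControlLoRA weights plus optimizer and LR scheduler state.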
-
- # Save the lora layers
- accelerator.wait_for_everyone()
- if accelerator.is_main_process:
- unet = unet.to(torch.float32)
- # unet.save_attn_procs(args.output_dir)
- control_lora.save_config(args.output_dir)
- control_lora.save_pretrained(args.output_dir, safe_serialization=False)
- control_lora.save_pretrained(args.output_dir, safe_serialization=True)
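- # Saving twice exports both the PyTorch `.bin` and the `.safetensors` weight files.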
-
- if args.push_to_hub:
- save_model_card(
- repo_name,
- images=None,  # sample images are generated after saving (see the inference loop below)
- base_model=args.pretrained_model_name_or_path,
- dataset_name=args.dataset_name,
- repo_folder=args.output_dir,
- )
- repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
-
- # Final inference
- # Load previous pipeline
- pipeline = DiffusionPipeline.from_pretrained(
- args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype, safety_checker=None
- )
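- # Swap in DPMSolverMultistepScheduler, a fast solver that works well with the 30 inference steps used below.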
- pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
- pipeline = pipeline.to(accelerator.device)
-
- # load attention processors
- lora_attn_procs = {}
- lora_layers_list = [list(layer_list) for layer_list in control_lora.lora_layers]
- for name in pipeline.unet.attn_processors.keys():
- cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
- if name.startswith("mid_block"):
- control_id = control_ids[-1]
- elif name.startswith("up_blocks"):
- block_id = int(name[len("up_blocks.")])
- control_id = list(reversed(control_ids))[block_id]
- elif name.startswith("down_blocks"):
- block_id = int(name[len("down_blocks.")])
- control_id = control_ids[block_id]
-
- lora_layers = lora_layers_list[control_id]
- if len(lora_layers) != 0:
- lora_layer = lora_layers.pop(0)
- lora_attn_procs[name] = lora_layer
-
- pipeline.unet.set_attn_processor(lora_attn_procs)
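- # The pipeline now shares the trained LoRA attention layers held by `control_lora`,
- # so guide conditioning flows into the UNet through those processors.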
-
- # run inference
- generator = torch.Generator(device=accelerator.device)
- if args.seed is not None:
- generator.manual_seed(args.seed)
- images = []
- sample_dir = os.path.join("samples", os.path.basename(args.output_dir))
- os.makedirs(sample_dir, exist_ok=True)
- for i in range(args.num_validation_images):
- with torch.no_grad():
- try:
- batch = next(val_iter)
- except StopIteration:
- val_iter = iter(val_dataloader)
- batch = next(val_iter)
- target = batch["pixel_values"].to(dtype=weight_dtype)
- guide = batch["guide_values"].to(accelerator.device)
- # Run the guide through ControlLoRA for its side effect: per the ControlLoRA design,
- # the forward pass caches control states that the shared LoRA attention processors
- # consume during sampling.
- _ = control_lora(guide).control_states
- image = pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
- image = dataset_cls.cat_input(image, target, guide)
- image.save(os.path.join(sample_dir, f"{i}.png"))
-
-
- if __name__ == "__main__":
- main()