- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
- from typing import List, Tuple, Union
- from dataclasses import dataclass
- from diffusers.utils.outputs import BaseOutput
- from diffusers.configuration_utils import ConfigMixin, register_to_config
- from diffusers.models.modeling_utils import ModelMixin
- from diffusers.models.unet_2d_blocks import get_down_block as get_down_block_default
- from functools import partial
- from diffusers.models.resnet import Mish, Upsample2D, Downsample2D, upsample_2d, downsample_2d
- from diffusers.models.cross_attention import CrossAttention, LoRALinearLayer  # LoRACrossAttnProcessor is re-implemented below
-
-
- def get_down_block(
- down_block_type,
- num_layers,
- in_channels,
- out_channels,
- temb_channels,
- add_downsample,
- resnet_eps,
- resnet_act_fn,
- attn_num_head_channels,
- resnet_groups=None,
- cross_attention_dim=None,
- downsample_padding=None,
- dual_cross_attention=False,
- use_linear_projection=False,
- only_cross_attention=False,
- upcast_attention=False,
- resnet_time_scale_shift="default",
- resnet_kernel_size=3,
- ):
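- """Route `SimpleDownEncoderBlock2D` to the local implementation; every other block type is delegated to diffusers' stock `get_down_block`."""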
- down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
- if down_block_type == "SimpleDownEncoderBlock2D":
- return SimpleDownEncoderBlock2D(
- num_layers=num_layers,
- in_channels=in_channels,
- out_channels=out_channels,
- add_downsample=add_downsample,
- convnet_eps=resnet_eps,
- convnet_act_fn=resnet_act_fn,
- convnet_groups=resnet_groups,
- downsample_padding=downsample_padding,
- convnet_time_scale_shift=resnet_time_scale_shift,
- convnet_kernel_size=resnet_kernel_size
- )
- else:
- return get_down_block_default(
- down_block_type,
- num_layers,
- in_channels,
- out_channels,
- temb_channels,
- add_downsample,
- resnet_eps,
- resnet_act_fn,
- attn_num_head_channels,
- resnet_groups=resnet_groups,
- cross_attention_dim=cross_attention_dim,
- downsample_padding=downsample_padding,
- dual_cross_attention=dual_cross_attention,
- use_linear_projection=use_linear_projection,
- only_cross_attention=only_cross_attention,
- upcast_attention=upcast_attention,
- resnet_time_scale_shift=resnet_time_scale_shift,
- # resnet_kernel_size is only consumed by SimpleDownEncoderBlock2D above, so it is not forwarded here
- )
-
-
- class LoRACrossAttnProcessor(nn.Module):
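- """Cross-attention processor that adds low-rank (LoRA) updates to the query/key/value/output
- projections of a diffusers `CrossAttention` module. The key, value and output updates can be
- skipped individually, and `post_add=True` feeds the already-projected states (instead of the
- raw inputs) into the LoRA layers."""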
- def __init__(
- self,
- hidden_size,
- cross_attention_dim=None,
- rank=4,
- post_add=False,
- key_states_skipped=False,
- value_states_skipped=False,
- output_states_skipped=False):
- super().__init__()
-
- self.hidden_size = hidden_size
- self.cross_attention_dim = cross_attention_dim
- self.rank = rank
- self.post_add = post_add
-
- self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
- if not key_states_skipped:
- self.to_k_lora = LoRALinearLayer(
- hidden_size if post_add else (cross_attention_dim or hidden_size), hidden_size, rank)
- if not value_states_skipped:
- self.to_v_lora = LoRALinearLayer(
- hidden_size if post_add else (cross_attention_dim or hidden_size), hidden_size, rank)
- if not output_states_skipped:
- self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
-
- self.key_states_skipped: bool = key_states_skipped
- self.value_states_skipped: bool = value_states_skipped
- self.output_states_skipped: bool = output_states_skipped
-
- def skip_key_states(self, is_skipped: bool = True):
- if not is_skipped:
- assert hasattr(self, 'to_k_lora'), 'to_k_lora was never created; cannot re-enable key LoRA'
- self.key_states_skipped = is_skipped
-
- def skip_value_states(self, is_skipped: bool = True):
- if not is_skipped:
- assert hasattr(self, 'to_v_lora'), 'to_v_lora was never created; cannot re-enable value LoRA'
- self.value_states_skipped = is_skipped
-
- def skip_output_states(self, is_skipped: bool = True):
- if not is_skipped:
- assert hasattr(self, 'to_out_lora'), 'to_out_lora was never created; cannot re-enable output LoRA'
- self.output_states_skipped = is_skipped
-
- def __call__(
- self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
- ):
- batch_size, sequence_length, _ = hidden_states.shape
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-
- query = attn.to_q(hidden_states)
- query = query + scale * self.to_q_lora(query if self.post_add else hidden_states)
- query = attn.head_to_batch_dim(query)
-
- encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
-
- key = attn.to_k(encoder_hidden_states)
- if not self.key_states_skipped:
- key = key + scale * self.to_k_lora(key if self.post_add else encoder_hidden_states)
- value = attn.to_v(encoder_hidden_states)
- if not self.value_states_skipped:
- value = value + scale * self.to_v_lora(value if self.post_add else encoder_hidden_states)
-
- key = attn.head_to_batch_dim(key)
- value = attn.head_to_batch_dim(value)
-
- attention_probs = attn.get_attention_scores(query, key, attention_mask)
- hidden_states = torch.bmm(attention_probs, value)
- hidden_states = attn.batch_to_head_dim(hidden_states)
-
- # linear proj
- out = attn.to_out[0](hidden_states)
- if not self.output_states_skipped:
- out = out + scale * self.to_out_lora(out if self.post_add else hidden_states)
- hidden_states = out
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- return hidden_states
-
-
- class ControlLoRACrossAttnProcessor(LoRACrossAttnProcessor):
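- """LoRA attention processor that also mixes externally injected control states into the query
- path. Control states are set via `inject_control_states` (done by `ControlLoRA.forward`),
- optionally concatenated with the hidden states, projected by `to_control`, and added to the
- input of the query LoRA. Additional processors can be chained with `inject_pre_lora` /
- `inject_post_lora`."""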
- def __init__(
- self,
- hidden_size,
- cross_attention_dim=None,
- rank=4,
- control_rank=None,
- post_add=False,
- concat_hidden=False,
- control_channels=None,
- control_self_add=True,
- key_states_skipped=False,
- value_states_skipped=False,
- output_states_skipped=False,
- **kwargs):
- super().__init__(
- hidden_size,
- cross_attention_dim,
- rank,
- post_add=post_add,
- key_states_skipped=key_states_skipped,
- value_states_skipped=value_states_skipped,
- output_states_skipped=output_states_skipped)
-
- control_rank = rank if control_rank is None else control_rank
- self.concat_hidden = concat_hidden
- # Adding the raw control states back onto their projection only matches shapes when the
- # control states already have `hidden_size` channels, i.e. when `control_channels` was left unset.
- self.control_self_add = control_self_add if control_channels is None else False
- control_channels = hidden_size if control_channels is None else control_channels
- self.control_states: torch.Tensor = None
-
- self.to_control = LoRALinearLayer(
- control_channels + (hidden_size if concat_hidden else 0),
- hidden_size,
- control_rank)
- self.pre_loras: List[LoRACrossAttnProcessor] = []
- self.post_loras: List[LoRACrossAttnProcessor] = []
-
- def inject_pre_lora(self, lora_layer):
- self.pre_loras.append(lora_layer)
-
- def inject_post_lora(self, lora_layer):
- self.post_loras.append(lora_layer)
-
- def inject_control_states(self, control_states):
- self.control_states = control_states
-
- def process_control_states(self, hidden_states, scale=1.0):
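- """Flatten injected 4D control states to (batch, seq_len, channels), optionally concatenate
- them with `hidden_states` (broadcasting over the batch if needed), project them through
- `to_control`, and, when `control_self_add` is set, add the projection back onto the raw
- control states."""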
- control_states = self.control_states.to(hidden_states.dtype)
- if hidden_states.ndim == 3 and control_states.ndim == 4:
- batch, _, height, width = control_states.shape
- control_states = control_states.permute(0, 2, 3, 1).reshape(batch, height * width, -1)
- self.control_states = control_states
- _control_states = control_states
- if self.concat_hidden:
- b1, b2 = control_states.shape[0], hidden_states.shape[0]
- if b1 != b2:
- control_states = control_states[:,None].repeat(1, b2//b1, *([1]*(len(control_states.shape)-1)))
- control_states = control_states.view(-1, *control_states.shape[2:])
- _control_states = torch.cat([hidden_states, control_states], -1)
- _control_states = scale * self.to_control(_control_states)
- if self.control_self_add:
- control_states = control_states + _control_states
- else:
- control_states = _control_states
-
- return control_states
-
- def __call__(
- self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
- ):
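- """Same attention computation as `LoRACrossAttnProcessor.__call__`, except that every
- projection accumulates LoRA updates in order: `pre_loras`, this processor (with control
- states mixed into the query LoRA input), then `post_loras`."""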
- pre_lora: LoRACrossAttnProcessor
- post_lora: LoRACrossAttnProcessor
- assert self.control_states is not None
-
- batch_size, sequence_length, _ = hidden_states.shape
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
- query = attn.to_q(hidden_states)
- for pre_lora in self.pre_loras:
- lora_in = query if pre_lora.post_add else hidden_states
- if isinstance(pre_lora, ControlLoRACrossAttnProcessor):
- lora_in = lora_in + pre_lora.process_control_states(hidden_states, scale)
- query = query + scale * pre_lora.to_q_lora(lora_in)
- query = query + scale * self.to_q_lora((
- query if self.post_add else hidden_states) + self.process_control_states(hidden_states, scale))
- for post_lora in self.post_loras:
- lora_in = query if post_lora.post_add else hidden_states
- if isinstance(post_lora, ControlLoRACrossAttnProcessor):
- lora_in = lora_in + post_lora.process_control_states(hidden_states, scale)
- query = query + scale * post_lora.to_q_lora(lora_in)
- query = attn.head_to_batch_dim(query)
-
- encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
-
- key = attn.to_k(encoder_hidden_states)
- for pre_lora in self.pre_loras:
- if not pre_lora.key_states_skipped:
- key = key + scale * pre_lora.to_k_lora(key if pre_lora.post_add else encoder_hidden_states)
- if not self.key_states_skipped:
- key = key + scale * self.to_k_lora(key if self.post_add else encoder_hidden_states)
- for post_lora in self.post_loras:
- if not post_lora.key_states_skipped:
- key = key + scale * post_lora.to_k_lora(key if post_lora.post_add else encoder_hidden_states)
- value = attn.to_v(encoder_hidden_states)
- for pre_lora in self.pre_loras:
- if not pre_lora.value_states_skipped:
- value = value + scale * pre_lora.to_v_lora(value if pre_lora.post_add else encoder_hidden_states)
- if not self.value_states_skipped:
- value = value + scale * self.to_v_lora(value if self.post_add else encoder_hidden_states)
- for post_lora in self.post_loras:
- if not post_lora.value_states_skipped:
- value = value + scale * post_lora.to_v_lora(value if post_lora.post_add else encoder_hidden_states)
-
- key = attn.head_to_batch_dim(key)
- value = attn.head_to_batch_dim(value)
-
- attention_probs = attn.get_attention_scores(query, key, attention_mask)
- hidden_states = torch.bmm(attention_probs, value)
- hidden_states = attn.batch_to_head_dim(hidden_states)
-
- # linear proj
- out = attn.to_out[0](hidden_states)
- for pre_lora in self.pre_loras:
- if not pre_lora.output_states_skipped:
- out = out + scale * pre_lora.to_out_lora(out if pre_lora.post_add else hidden_states)
- out = out + scale * self.to_out_lora(out if self.post_add else hidden_states)
- for post_lora in self.post_loras:
- if not post_lora.output_states_skipped:
- out = out + scale * post_lora.to_out_lora(out if post_lora.post_add else hidden_states)
- hidden_states = out
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- return hidden_states
-
-
- class ControlLoRACrossAttnProcessorV2(LoRACrossAttnProcessor):
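- """Variant of `ControlLoRACrossAttnProcessor` that always concatenates hidden and control
- states, skips the key/value LoRA updates, and uses two separate control projections:
- `to_control` before the attention and `to_control_out` before the output projection."""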
- def __init__(
- self,
- hidden_size,
- cross_attention_dim=None,
- rank=4,
- control_rank=None,
- control_channels=None,
- **kwargs):
- super().__init__(
- hidden_size,
- cross_attention_dim,
- rank,
- post_add=False,
- key_states_skipped=True,
- value_states_skipped=True,
- output_states_skipped=False)
-
- control_rank = rank if control_rank is None else control_rank
- control_channels = hidden_size if control_channels is None else control_channels
- self.concat_hidden = True
- self.control_self_add = False
- self.control_states: torch.Tensor = None
-
- self.to_control = LoRALinearLayer(
- hidden_size + control_channels,
- hidden_size,
- control_rank)
- self.to_control_out = LoRALinearLayer(
- hidden_size + control_channels,
- hidden_size,
- control_rank)
- self.pre_loras: List[LoRACrossAttnProcessor] = []
- self.post_loras: List[LoRACrossAttnProcessor] = []
-
- def inject_pre_lora(self, lora_layer):
- self.pre_loras.append(lora_layer)
-
- def inject_post_lora(self, lora_layer):
- self.post_loras.append(lora_layer)
-
- def inject_control_states(self, control_states):
- self.control_states = control_states
-
- def process_control_states(self, hidden_states, scale=1.0, is_out=False):
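- """Like `ControlLoRACrossAttnProcessor.process_control_states`, but selects `to_control_out`
- instead of `to_control` when `is_out=True` (i.e. when mixing control states into the
- attention output path)."""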
- control_states = self.control_states.to(hidden_states.dtype)
- if hidden_states.ndim == 3 and control_states.ndim == 4:
- batch, _, height, width = control_states.shape
- control_states = control_states.permute(0, 2, 3, 1).reshape(batch, height * width, -1)
- self.control_states = control_states
- _control_states = control_states
- if self.concat_hidden:
- b1, b2 = control_states.shape[0], hidden_states.shape[0]
- if b1 != b2:
- control_states = control_states[:,None].repeat(1, b2//b1, *([1]*(len(control_states.shape)-1)))
- control_states = control_states.view(-1, *control_states.shape[2:])
- _control_states = torch.cat([hidden_states, control_states], -1)
- _control_states = scale * (self.to_control_out if is_out else self.to_control)(_control_states)
- if self.control_self_add:
- control_states = control_states + _control_states
- else:
- control_states = _control_states
-
- return control_states
-
- def __call__(
- self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
- ):
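- """Like `ControlLoRACrossAttnProcessor.__call__`, but control states are added directly onto
- `hidden_states` before `to_q` and again onto the attention output before `to_out`, instead of
- only inside the query LoRA update."""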
- pre_lora: LoRACrossAttnProcessor
- post_lora: LoRACrossAttnProcessor
- assert self.control_states is not None
-
- batch_size, sequence_length, _ = hidden_states.shape
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
- for pre_lora in self.pre_loras:
- if isinstance(pre_lora, ControlLoRACrossAttnProcessorV2):
- hidden_states = hidden_states + pre_lora.process_control_states(hidden_states, scale)
- hidden_states = hidden_states + self.process_control_states(hidden_states, scale)
- for post_lora in self.post_loras:
- if isinstance(post_lora, ControlLoRACrossAttnProcessorV2):
- hidden_states = hidden_states + post_lora.process_control_states(hidden_states, scale)
- query = attn.to_q(hidden_states)
- for pre_lora in self.pre_loras:
- lora_in = query if pre_lora.post_add else hidden_states
- query = query + scale * pre_lora.to_q_lora(lora_in)
- query = query + scale * self.to_q_lora(query if self.post_add else hidden_states)
- for post_lora in self.post_loras:
- lora_in = query if post_lora.post_add else hidden_states
- query = query + scale * post_lora.to_q_lora(lora_in)
- query = attn.head_to_batch_dim(query)
-
- encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
-
- key = attn.to_k(encoder_hidden_states)
- for pre_lora in self.pre_loras:
- if not pre_lora.key_states_skipped:
- key = key + scale * pre_lora.to_k_lora(key if pre_lora.post_add else encoder_hidden_states)
- if not self.key_states_skipped:
- key = key + scale * self.to_k_lora(key if self.post_add else encoder_hidden_states)
- for post_lora in self.post_loras:
- if not post_lora.key_states_skipped:
- key = key + scale * post_lora.to_k_lora(key if post_lora.post_add else encoder_hidden_states)
- value = attn.to_v(encoder_hidden_states)
- for pre_lora in self.pre_loras:
- if not pre_lora.value_states_skipped:
- value = value + scale * pre_lora.to_v_lora(value if pre_lora.post_add else encoder_hidden_states)
- if not self.value_states_skipped:
- value = value + scale * self.to_v_lora(value if self.post_add else encoder_hidden_states)
- for post_lora in self.post_loras:
- if not post_lora.value_states_skipped:
- value = value + scale * post_lora.to_v_lora(value if post_lora.post_add else encoder_hidden_states)
-
- key = attn.head_to_batch_dim(key)
- value = attn.head_to_batch_dim(value)
-
- attention_probs = attn.get_attention_scores(query, key, attention_mask)
- hidden_states = torch.bmm(attention_probs, value)
- hidden_states = attn.batch_to_head_dim(hidden_states)
-
- # linear proj
- for pre_lora in self.pre_loras:
- if isinstance(pre_lora, ControlLoRACrossAttnProcessorV2):
- hidden_states = hidden_states + pre_lora.process_control_states(hidden_states, scale, is_out=True)
- hidden_states = hidden_states + self.process_control_states(hidden_states, scale, is_out=True)
- for post_lora in self.post_loras:
- if isinstance(post_lora, ControlLoRACrossAttnProcessorV2):
- hidden_states = hidden_states + post_lora.process_control_states(hidden_states, scale, is_out=True)
- out = attn.to_out[0](hidden_states)
- for pre_lora in self.pre_loras:
- if not pre_lora.output_states_skipped:
- out = out + scale * pre_lora.to_out_lora(out if pre_lora.post_add else hidden_states)
- out = out + scale * self.to_out_lora(out if self.post_add else hidden_states)
- for post_lora in self.post_loras:
- if not post_lora.output_states_skipped:
- out = out + scale * post_lora.to_out_lora(out if post_lora.post_add else hidden_states)
- hidden_states = out
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- return hidden_states
-
-
- class ConvBlock2D(nn.Module):
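- """Norm -> activation -> conv block modelled on diffusers' `ResnetBlock2D` but without the
- residual skip connection, with optional time-embedding conditioning and optional up/downsampling."""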
- def __init__(
- self,
- *,
- in_channels,
- out_channels=None,
- conv_kernel_size=3,
- dropout=0.0,
- temb_channels=512,
- groups=32,
- groups_out=None,
- pre_norm=True,
- eps=1e-6,
- non_linearity="swish",
- time_embedding_norm="default",
- kernel=None,
- output_scale_factor=1.0,
- up=False,
- down=False,
- ):
- super().__init__()
- self.pre_norm = True  # pre-norm is always used; the `pre_norm` argument is kept for interface compatibility
- self.in_channels = in_channels
- out_channels = in_channels if out_channels is None else out_channels
- self.out_channels = out_channels
- self.time_embedding_norm = time_embedding_norm
- self.up = up
- self.down = down
- self.output_scale_factor = output_scale_factor
-
- if groups_out is None:
- groups_out = groups
-
- self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
-
- self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=conv_kernel_size, stride=1, padding=conv_kernel_size//2)
-
- if temb_channels is not None:
- if self.time_embedding_norm == "default":
- time_emb_proj_out_channels = out_channels
- elif self.time_embedding_norm == "scale_shift":
- time_emb_proj_out_channels = out_channels * 2
- else:
- raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
-
- self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
- else:
- self.time_emb_proj = None
-
- self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
- self.dropout = torch.nn.Dropout(dropout)
-
- if non_linearity == "swish":
- self.nonlinearity = F.silu
- elif non_linearity == "mish":
- self.nonlinearity = Mish()
- elif non_linearity == "silu":
- self.nonlinearity = nn.SiLU()
- else:
- raise ValueError(f"unsupported non_linearity: {non_linearity}")
-
- self.upsample = self.downsample = None
- if self.up:
- if kernel == "fir":
- fir_kernel = (1, 3, 3, 1)
- self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
- elif kernel == "sde_vp":
- self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
- else:
- self.upsample = Upsample2D(in_channels, use_conv=False)
- elif self.down:
- if kernel == "fir":
- fir_kernel = (1, 3, 3, 1)
- self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
- elif kernel == "sde_vp":
- self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
- else:
- self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
-
- def forward(self, input_tensor, temb):
- hidden_states = input_tensor
-
- hidden_states = self.norm1(hidden_states)
- hidden_states = self.nonlinearity(hidden_states)
-
- if self.upsample is not None:
- # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
- if hidden_states.shape[0] >= 64:
- input_tensor = input_tensor.contiguous()
- hidden_states = hidden_states.contiguous()
- input_tensor = self.upsample(input_tensor)
- hidden_states = self.upsample(hidden_states)
- elif self.downsample is not None:
- input_tensor = self.downsample(input_tensor)
- hidden_states = self.downsample(hidden_states)
-
- hidden_states = self.conv1(hidden_states)
-
- if temb is not None:
- temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
-
- if temb is not None and self.time_embedding_norm == "default":
- hidden_states = hidden_states + temb
-
- hidden_states = self.norm2(hidden_states)
-
- if temb is not None and self.time_embedding_norm == "scale_shift":
- scale, shift = torch.chunk(temb, 2, dim=1)
- hidden_states = hidden_states * (1 + scale) + shift
-
- hidden_states = self.nonlinearity(hidden_states)
-
- output_tensor = self.dropout(hidden_states)
-
- return output_tensor
-
-
- class SimpleDownEncoderBlock2D(nn.Module):
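- """Encoder block of stacked `ConvBlock2D` layers (no time embedding) followed by an optional
- strided-conv downsampler."""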
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- dropout: float = 0.0,
- num_layers: int = 1,
- convnet_eps: float = 1e-6,
- convnet_time_scale_shift: str = "default",
- convnet_act_fn: str = "swish",
- convnet_groups: int = 32,
- convnet_pre_norm: bool = True,
- convnet_kernel_size: int = 3,
- output_scale_factor=1.0,
- add_downsample=True,
- downsample_padding=1,
- ):
- super().__init__()
- convnets = []
-
- for i in range(num_layers):
- in_channels = in_channels if i == 0 else out_channels
- convnets.append(
- ConvBlock2D(
- in_channels=in_channels,
- out_channels=out_channels,
- temb_channels=None,
- eps=convnet_eps,
- groups=convnet_groups,
- dropout=dropout,
- time_embedding_norm=convnet_time_scale_shift,
- non_linearity=convnet_act_fn,
- output_scale_factor=output_scale_factor,
- pre_norm=convnet_pre_norm,
- conv_kernel_size=convnet_kernel_size,
- )
- )
- in_channels = in_channels if num_layers == 0 else out_channels
-
- self.convnets = nn.ModuleList(convnets)
-
- if add_downsample:
- self.downsamplers = nn.ModuleList(
- [
- Downsample2D(
- in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
- )
- ]
- )
- else:
- self.downsamplers = None
-
- def forward(self, hidden_states):
- for convnet in self.convnets:
- hidden_states = convnet(hidden_states, temb=None)
-
- if self.downsamplers is not None:
- for downsampler in self.downsamplers:
- hidden_states = downsampler(hidden_states)
-
- return hidden_states
-
-
- @dataclass
- class ControlLoRAOutput(BaseOutput):
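- """Output of `ControlLoRA.forward`: one control-state tensor per down block."""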
- control_states: Tuple[torch.FloatTensor]
-
-
- class ControlLoRA(ModelMixin, ConfigMixin):
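- """Conditioning encoder for ControlLoRA. `conv_in` and `down_blocks` turn a control image into
- one feature map ("control states") per resolution, and `forward` injects those states into the
- `ControlLoRACrossAttnProcessor` layers held in `lora_layers`; the processors themselves are
- intended to be attached to a UNet's attention modules (e.g. via diffusers' `set_attn_processor`)."""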
- @register_to_config
- def __init__(
- self,
- in_channels: int = 3,
- down_block_types: Tuple[str] = (
- "SimpleDownEncoderBlock2D",
- "SimpleDownEncoderBlock2D",
- "SimpleDownEncoderBlock2D",
- "SimpleDownEncoderBlock2D",
- ),
- block_out_channels: Tuple[int] = (32, 64, 128, 256),
- layers_per_block: int = 1,
- act_fn: str = "silu",
- norm_num_groups: int = 32,
- lora_pre_down_block_types: Tuple[str] = (
- None,
- "SimpleDownEncoderBlock2D",
- "SimpleDownEncoderBlock2D",
- "SimpleDownEncoderBlock2D",
- ),
- lora_pre_down_layers_per_block: int = 1,
- lora_pre_conv_skipped: bool = False,
- lora_pre_conv_types: Tuple[str] = (
- "SimpleDownEncoderBlock2D",
- "SimpleDownEncoderBlock2D",
- "SimpleDownEncoderBlock2D",
- "SimpleDownEncoderBlock2D",
- ),
- lora_pre_conv_layers_per_block: int = 1,
- lora_pre_conv_layers_kernel_size: int = 1,
- lora_block_in_channels: Tuple[int] = (256, 256, 256, 256),
- lora_block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
- lora_cross_attention_dims: Tuple[List[int]] = (
- [None, 768, None, 768, None, 768, None, 768, None, 768],
- [None, 768, None, 768, None, 768, None, 768, None, 768],
- [None, 768, None, 768, None, 768, None, 768, None, 768],
- [None, 768]
- ),
- lora_rank: int = 4,
- lora_control_rank: int = None,
- lora_post_add: bool = False,
- lora_concat_hidden: bool = False,
- lora_control_channels: Tuple[int] = (None, None, None, None),
- lora_control_self_add: bool = True,
- lora_key_states_skipped: bool = False,
- lora_value_states_skipped: bool = False,
- lora_output_states_skipped: bool = False,
- lora_control_version: int = 1
- ):
- super().__init__()
-
- lora_control_cls = ControlLoRACrossAttnProcessor
- if lora_control_version == 2:
- lora_control_cls = ControlLoRACrossAttnProcessorV2
-
- assert lora_block_in_channels[0] == block_out_channels[-1]
-
- if lora_pre_conv_skipped:
- lora_control_channels = lora_block_in_channels
- lora_control_self_add = False
-
- self.layers_per_block = layers_per_block
- self.lora_pre_down_layers_per_block = lora_pre_down_layers_per_block
- self.lora_pre_conv_layers_per_block = lora_pre_conv_layers_per_block
-
- self.conv_in = torch.nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
-
- self.down_blocks = nn.ModuleList([])
- self.pre_lora_layers = nn.ModuleList([])
- self.lora_layers = nn.ModuleList([])
-
- # pre_down
- pre_down_blocks = []
- output_channel = block_out_channels[0]
- for i, down_block_type in enumerate(down_block_types):
- input_channel = output_channel
- output_channel = block_out_channels[i]
- is_final_block = i == len(block_out_channels) - 1
-
- pre_down_block = get_down_block(
- down_block_type,
- num_layers=self.layers_per_block,
- in_channels=input_channel,
- out_channels=output_channel,
- add_downsample=not is_final_block,
- resnet_eps=1e-6,
- downsample_padding=0,
- resnet_act_fn=act_fn,
- resnet_groups=norm_num_groups,
- attn_num_head_channels=None,
- temb_channels=None,
- )
- pre_down_blocks.append(pre_down_block)
- self.down_blocks.append(nn.Sequential(*pre_down_blocks))
- self.pre_lora_layers.append(
- get_down_block(
- lora_pre_conv_types[0],
- num_layers=self.lora_pre_conv_layers_per_block,
- in_channels=lora_block_in_channels[0],
- out_channels=(
- lora_block_out_channels[0]
- if lora_control_channels[0] is None
- else lora_control_channels[0]),
- add_downsample=False,
- resnet_eps=1e-6,
- downsample_padding=0,
- resnet_act_fn=act_fn,
- resnet_groups=norm_num_groups,
- attn_num_head_channels=None,
- temb_channels=None,
- resnet_kernel_size=lora_pre_conv_layers_kernel_size,
- ) if not lora_pre_conv_skipped else nn.Identity()
- )
- self.lora_layers.append(
- nn.ModuleList([
- lora_control_cls(
- lora_block_out_channels[0],
- cross_attention_dim=cross_attention_dim,
- rank=lora_rank,
- control_rank=lora_control_rank,
- post_add=lora_post_add,
- concat_hidden=lora_concat_hidden,
- control_channels=lora_control_channels[0],
- control_self_add=lora_control_self_add,
- key_states_skipped=lora_key_states_skipped,
- value_states_skipped=lora_value_states_skipped,
- output_states_skipped=lora_output_states_skipped)
- for cross_attention_dim in lora_cross_attention_dims[0]
- ])
- )
-
- # down
- output_channel = lora_block_in_channels[0]
- for i, down_block_type in enumerate(lora_pre_down_block_types):
- if i == 0:
- continue
- input_channel = output_channel
- output_channel = lora_block_in_channels[i]
-
- down_block = get_down_block(
- down_block_type,
- num_layers=self.lora_pre_down_layers_per_block,
- in_channels=input_channel,
- out_channels=output_channel,
- add_downsample=True,
- resnet_eps=1e-6,
- downsample_padding=0,
- resnet_act_fn=act_fn,
- resnet_groups=norm_num_groups,
- attn_num_head_channels=None,
- temb_channels=None,
- )
- self.down_blocks.append(down_block)
-
- self.pre_lora_layers.append(
- get_down_block(
- lora_pre_conv_types[i],
- num_layers=self.lora_pre_conv_layers_per_block,
- in_channels=output_channel,
- out_channels=(
- lora_block_out_channels[i]
- if lora_control_channels[i] is None
- else lora_control_channels[i]),
- add_downsample=False,
- resnet_eps=1e-6,
- downsample_padding=0,
- resnet_act_fn=act_fn,
- resnet_groups=norm_num_groups,
- attn_num_head_channels=None,
- temb_channels=None,
- resnet_kernel_size=lora_pre_conv_layers_kernel_size,
- ) if not lora_pre_conv_skipped else nn.Identity()
- )
- self.lora_layers.append(
- nn.ModuleList([
- lora_control_cls(
- lora_block_out_channels[i],
- cross_attention_dim=cross_attention_dim,
- rank=lora_rank,
- control_rank=lora_control_rank,
- post_add=lora_post_add,
- concat_hidden=lora_concat_hidden,
- control_channels=lora_control_channels[i],
- control_self_add=lora_control_self_add,
- key_states_skipped=lora_key_states_skipped,
- value_states_skipped=lora_value_states_skipped,
- output_states_skipped=lora_output_states_skipped)
- for cross_attention_dim in lora_cross_attention_dims[i]
- ])
- )
-
- def forward(self, x: torch.FloatTensor, return_dict: bool = True) -> Union[ControlLoRAOutput, Tuple]:
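- """Encode the control image `x` into per-block control states, inject them into the registered
- LoRA attention processors, and return them (wrapped in a `ControlLoRAOutput` unless
- `return_dict=False`)."""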
- lora_layer: ControlLoRACrossAttnProcessor
-
- orig_dtype = x.dtype
- dtype = self.conv_in.weight.dtype
-
- h = x.to(dtype)
- h = self.conv_in(h)
- control_states_list = []
-
- # down
- for down_block, pre_lora_layer, lora_layer_list in zip(
- self.down_blocks, self.pre_lora_layers, self.lora_layers):
- h = down_block(h)
- control_states = pre_lora_layer(h)
- if isinstance(control_states, tuple):
- control_states = control_states[0]
- control_states = control_states.to(orig_dtype)
- for lora_layer in lora_layer_list:
- lora_layer.inject_control_states(control_states)
- control_states_list.append(control_states)
-
- if not return_dict:
- return tuple(control_states_list)
-
- return ControlLoRAOutput(control_states=tuple(control_states_list))
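-
-
- if __name__ == "__main__":
-     # Minimal smoke-test sketch (not part of the original module). It assumes the default
-     # config above and a single random 3x512x512 conditioning image, and only checks that
-     # the encoder emits one control-state tensor per down block.
-     control_lora = ControlLoRA()
-     cond_image = torch.randn(1, 3, 512, 512)
-     with torch.no_grad():
-         output = control_lora(cond_image)
-     for i, control_states in enumerate(output.control_states):
-         print(f"block {i}: control states shape {tuple(control_states.shape)}")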