# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import math

import numpy as np
import timm

import mindspore
from mindspore import Tensor
from mindspore import numpy as mnp

import FrEIA.framework as Ff
import FrEIA.modules as Fm
from FrEIA.Decoder import *
from custom_models import *

def positionalencoding2d(D, H, W):
    """
    Build a fixed 2D sin/cos positional encoding.

    :param D: dimension of the model
    :param H: height of the positions
    :param W: width of the positions
    :return: DxHxW position matrix
    """
    if D % 4 != 0:
        raise ValueError("Cannot use sin/cos positional encoding with a dimension "
                         "not divisible by 4 (got dim={:d})".format(D))
    P = Tensor(np.zeros([D, H, W]), mindspore.float32)
    # Half of the channels encode the W axis, the other half the H axis.
    D = D // 2
    div_term = mnp.exp(mnp.arange(0.0, D, 2) * -(math.log(1e4) / D))
    pos_w = mnp.arange(0.0, W).expand_dims(1)
    pos_h = mnp.arange(0.0, H).expand_dims(1)
    P[0:D:2, :, :] = mnp.sin(pos_w * div_term).transpose().expand_dims(1).repeat(H, axis=1)
    P[1:D:2, :, :] = mnp.cos(pos_w * div_term).transpose().expand_dims(1).repeat(H, axis=1)
    P[D::2, :, :] = mnp.sin(pos_h * div_term).transpose().expand_dims(2).repeat(W, axis=2)
    P[D + 1::2, :, :] = mnp.cos(pos_h * div_term).transpose().expand_dims(2).repeat(W, axis=2)
    return P
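
# Usage sketch (hypothetical sizes, not from the original pipeline): CFLOW-style
# training conditions each decoder on spatial position, so a typical call
# builds one grid per pooled feature map and flattens it per location:
#
#     pos = positionalencoding2d(c.condition_vec, 16, 16)  # (D, 16, 16)
#     cond = pos.reshape(c.condition_vec, -1).transpose()  # (16*16, D)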

def subnet_fc(dims_in, dims_out):
    # Two-layer MLP used by FrEIA to predict the per-block affine parameters;
    # FrEIA calls this constructor once per coupling block with the split
    # channel counts.
    return mindspore.nn.SequentialCell(
        mindspore.nn.Dense(dims_in, 2 * dims_in),
        mindspore.nn.ReLU(),
        mindspore.nn.Dense(2 * dims_in, dims_out))

def freia_flow_head(c, n_feat):
    coder = Ff.SequenceINN(n_feat)
    print('NF coder:', n_feat)
    for _ in range(c.coupling_blocks):
        coder.append(Fm.AllInOneBlock, subnet_constructor=subnet_fc, affine_clamping=c.clamp_alpha,
                     global_affine_type='SOFTPLUS', permute_soft=True)
    return coder

def freia_cflow_head(c, n_feat):
    n_cond = c.condition_vec
    coder = Ff.SequenceINN(n_feat)
    print('CNF coder:', n_feat)
    for _ in range(c.coupling_blocks):
        # cond=0: each block reads its condition from the first entry of the
        # condition list passed at forward time.
        coder.append(Fm.AllInOneBlock, cond=0, cond_shape=(n_cond,), subnet_constructor=subnet_fc,
                     affine_clamping=c.clamp_alpha, global_affine_type='SOFTPLUS', permute_soft=True)
    return coder
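
# Usage sketch (assumes this FrEIA port keeps the upstream SequenceINN call
# signature): the conditional decoder takes the flattened features plus the
# positional condition and returns latents and the log-Jacobian determinant:
#
#     z, log_jac_det = decoder(features, c=[pos_cond])
#
# where `features` is (N, n_feat) and `pos_cond` is (N, c.condition_vec).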

def load_decoder_arch(c, dim_in):
    if c.dec_arch == 'freia-flow':
        decoder = freia_flow_head(c, dim_in)
    elif c.dec_arch == 'freia-cflow':
        decoder = freia_cflow_head(c, dim_in)
    else:
        raise NotImplementedError('{} is not a supported normalizing flow!'.format(c.dec_arch))
    # Run the decoder in float16.
    decoder.add_flags_recursive(fp16=True)
    return decoder
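
# Usage sketch (hypothetical call site, mirroring the upstream CFLOW training
# script): one decoder is built per pooled encoder scale:
#
#     decoders = [load_decoder_arch(c, pool_dim) for pool_dim in pool_dims]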

def load_new_decoder_arch(c, dim_in):
    return CFLOWDecoder(c, dim_in)

# Forward activations captured by the hooks registered in load_encoder_arch.
activation = {}
def get_activation(name):
    def hook(model, input, output):
        # Store the module's forward output under the given name.
        activation[name] = output
    return hook
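
# Example (see the ViT branch below): registering the hook on a transformer
# block stores its forward output in `activation` under the chosen key:
#
#     encoder.blocks[6].register_forward_hook(get_activation('layer0'))
#     _ = encoder(x)  # activation['layer0'] now holds the block's features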


def load_encoder_arch(c, L):
    # Encoder pretrained on natural images:
    pool_cnt = 0
    pool_dims = list()
    pool_layers = ['layer' + str(i) for i in range(L)]
    if 'resnet' in c.enc_arch:
        if c.enc_arch == 'resnet18':
            encoder = resnet18(pretrained=True, progress=True, L=L)
        elif c.enc_arch == 'resnet34':
            encoder = resnet34(pretrained=True, progress=True, L=L)
        elif c.enc_arch == 'resnet50':
            encoder = resnet50(pretrained=True, progress=True, L=L)
        elif c.enc_arch == 'resnext50_32x4d':
            encoder = resnext50_32x4d(pretrained=True, progress=True, L=L)
        elif c.enc_arch == 'wide_resnet50_2':
            encoder = wide_resnet50_2(pretrained=True, progress=True, L=L)
        else:
            raise NotImplementedError('{} is not a supported architecture!'.format(c.enc_arch))
        # The pooled feature width is the output of the last conv in the final
        # block of each stage: BasicBlock (resnet18/34) ends in conv2, while
        # Bottleneck variants (resnet50, resnext50, wide_resnet50_2) end in conv3.
        if L >= 3:
            block = encoder.layer2[-1]
            pool_dims.append(block.conv3.out_channels if hasattr(block, 'conv3') else block.conv2.out_channels)
            pool_cnt += 1
        if L >= 2:
            block = encoder.layer3[-1]
            pool_dims.append(block.conv3.out_channels if hasattr(block, 'conv3') else block.conv2.out_channels)
            pool_cnt += 1
        if L >= 1:
            block = encoder.layer4[-1]
            pool_dims.append(block.conv3.out_channels if hasattr(block, 'conv3') else block.conv2.out_channels)
            pool_cnt += 1
    elif 'vit' in c.enc_arch:
        if c.enc_arch == 'vit_base_patch16_224':
            encoder = timm.create_model('vit_base_patch16_224', pretrained=True)
        elif c.enc_arch == 'vit_base_patch16_384':
            encoder = timm.create_model('vit_base_patch16_384', pretrained=True)
        else:
            raise NotImplementedError('{} is not a supported architecture!'.format(c.enc_arch))
        # Capture the forward activations of three transformer blocks; every
        # block shares the same embedding width, hence the repeated fc2 lookup.
        if L >= 3:
            encoder.blocks[10].register_forward_hook(get_activation(pool_layers[pool_cnt]))
            pool_dims.append(encoder.blocks[6].mlp.fc2.out_features)
            pool_cnt += 1
        if L >= 2:
            encoder.blocks[2].register_forward_hook(get_activation(pool_layers[pool_cnt]))
            pool_dims.append(encoder.blocks[6].mlp.fc2.out_features)
            pool_cnt += 1
        if L >= 1:
            encoder.blocks[6].register_forward_hook(get_activation(pool_layers[pool_cnt]))
            pool_dims.append(encoder.blocks[6].mlp.fc2.out_features)
            pool_cnt += 1
    elif 'efficient' in c.enc_arch:
        if 'b5' in c.enc_arch:
            encoder = timm.create_model(c.enc_arch, pretrained=True)
            blocks = [-2, -3, -5]
        else:
            raise NotImplementedError('{} is not a supported architecture!'.format(c.enc_arch))
        # Feature widths of the selected inverted-residual stages, taken from
        # the last batch norm of each stage's final block.
        if L >= 3:
            pool_dims.append(encoder.blocks[blocks[2]][-1].bn3.num_features)
            pool_cnt += 1
        if L >= 2:
            pool_dims.append(encoder.blocks[blocks[1]][-1].bn3.num_features)
            pool_cnt += 1
        if L >= 1:
            pool_dims.append(encoder.blocks[blocks[0]][-1].bn3.num_features)
            pool_cnt += 1
    elif 'mobile' in c.enc_arch:
        if c.enc_arch == 'mobilenet_v3_small':
            encoder = mobilenet_v3_small(pretrained=False, progress=True)
            blocks = [-2, -5, -10]
        elif c.enc_arch == 'mobilenet_v3_large':
            encoder = mobilenet_v3_large(pretrained=False, progress=True)
            blocks = [-2, -5, -11]
        else:
            raise NotImplementedError('{} is not a supported architecture!'.format(c.enc_arch))
        # Output widths of the selected inverted-residual blocks.
        if L >= 3:
            pool_dims.append(encoder.features[blocks[2]].block[-1][-3].out_channels)
            pool_cnt += 1
        if L >= 2:
            pool_dims.append(encoder.features[blocks[1]].block[-1][-3].out_channels)
            pool_cnt += 1
        if L >= 1:
            pool_dims.append(encoder.features[blocks[0]].block[-1][-3].out_channels)
            pool_cnt += 1
    else:
        raise NotImplementedError('{} is not a supported architecture!'.format(c.enc_arch))
    # Run the encoder in float16 (MindSpore Cell flag).
    encoder.add_flags_recursive(fp16=True)
    return encoder, pool_layers, pool_dims
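
# Minimal smoke test (an addition for illustration, not part of the original
# training pipeline): exercises the positional encoding helper.
if __name__ == '__main__':
    pe = positionalencoding2d(128, 8, 8)  # 128 channels on an 8x8 grid
    print('positional encoding shape:', pe.shape)  # expect (128, 8, 8)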