|
- import os
- import numpy as np
- import mindspore
- import mindspore.nn as nn
- import mindspore.ops as ops
- from mindspore import context, Tensor
- from mindspore.common.initializer import initializer, HeNormal
- import torch
- import timm
- from timm.data import resolve_data_config
- from typing import Type, Any, Callable, Union, List, Optional
- from model import load_decoder_arch, load_encoder_arch, positionalencoding2d, activation
- from custom_models import *
- from config import get_args
- import os, random, time, math
-
-
def init_seeds(seed=0):
    """Seed every RNG used by this script (Python, NumPy, MindSpore).

    Args:
        seed: integer seed applied to all three libraries.
    """
    for seed_fn in (random.seed, np.random.seed, mindspore.set_seed):
        seed_fn(seed)
-
# NOTE(review): N is only referenced inside the commented-out block in
# testTrain (FIB = E // N); it is unused by the active code paths.
N = 256
def testTrain(c):
    """Debug helper: build the encoder and decoders from config `c` and dump
    every parameter tensor to stdout.

    Args:
        c: parsed config namespace (see config.get_args); must provide
           `pool_layers` plus whatever load_encoder_arch / load_decoder_arch
           read from it.
    """
    print("------------------------")
    L = c.pool_layers
    encoder, pool_layers, pool_dims = load_encoder_arch(c, L)
    encoder_params = encoder.parameters_dict()
    for param_tensor in encoder_params:
        print(param_tensor, 't')
        print(encoder_params[param_tensor])
    decoders = [load_decoder_arch(c, pool_dim) for pool_dim in pool_dims]

    # Bug fix: the original printed decoders[0] three times under the labels
    # decoder0/decoder1/decoder2 — iterate over every decoder instead, so the
    # label always matches the decoder actually being dumped.
    for i, decoder in enumerate(decoders):
        print('decoder{}---------------------'.format(i))
        decoder_params = decoder.parameters_dict()
        print(type(decoder_params))
        for param_tensor in decoder_params:
            print(param_tensor, 't', decoder_params[param_tensor])
-
def main(c):
    """Derive run settings on config namespace `c`, then dispatch the action.

    Fills in derived fields on `c` (model name, image/crop sizes,
    normalization stats, flow hyperparameters, dataset path, LR schedule,
    device context) and invokes the selected action (currently testTrain).

    Args:
        c: parsed config namespace from config.get_args().

    Raises:
        NotImplementedError: for an unsupported action type or dataset.
    """
    # model
    if c.action_type in ['norm-train', 'norm-test']:
        c.model = "{}_{}_{}_pl{}_cb{}_inp{}_run{}_{}".format(
            c.dataset, c.enc_arch, c.dec_arch, c.pool_layers, c.coupling_blocks, c.input_size, c.run_name, c.class_name)
    else:
        raise NotImplementedError('{} is not supported action-type!'.format(c.action_type))
    # image
    if ('vit' in c.enc_arch) or ('efficient' in c.enc_arch):
        encoder = timm.create_model(c.enc_arch, pretrained=True)
        arch_config = resolve_data_config({}, model=encoder)
        # Bug fix: norm_std was previously copied from arch_config['mean'],
        # so the std of the pretrained model was never used.
        c.norm_mean, c.norm_std = list(arch_config['mean']), list(arch_config['std'])
        c.img_size = arch_config['input_size'][1:]  # HxW format
        c.crp_size = arch_config['input_size'][1:]  # HxW format
    else:
        c.img_size = (c.input_size, c.input_size)  # HxW format
        c.crp_size = (c.input_size, c.input_size)  # HxW format
        if c.dataset == 'stc':
            c.norm_mean, c.norm_std = 3 * [0.5], 3 * [0.225]
        else:
            # ImageNet statistics
            c.norm_mean, c.norm_std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    #
    c.img_dims = [3] + list(c.img_size)
    # network hyperparameters
    c.clamp_alpha = 1.9  # see paper equation 2 for explanation
    c.condition_vec = 128
    c.dropout = 0.0  # dropout in s-t-networks
    c.bs = c.batch_size
    # dataloader parameters
    if c.dataset == 'mvtec':
        c.data_path = '/mass_store/dataset/MVTec/'
    elif c.dataset == 'stc':
        c.data_path = './data/STC/shanghaitech'
    elif c.dataset == 'video':
        c.data_path = c.video_path
    else:
        raise NotImplementedError('{} is not supported dataset!'.format(c.dataset))
    # output settings
    c.verbose = True
    c.hide_tqdm_bar = True
    c.save_results = True
    # unsup-train
    c.print_freq = 2
    c.temp = 0.5
    # decay LR at 50%/75%/90% of the total meta-epochs
    c.lr_decay_epochs = [i * c.meta_epochs // 100 for i in [50, 75, 90]]
    print('LR schedule: {}'.format(c.lr_decay_epochs))
    c.lr_decay_rate = 0.1
    c.lr_warm_epochs = 2
    c.lr_warm = True
    c.lr_cosine = True
    if c.lr_warm:
        c.lr_warmup_from = c.lr / 10.0
        if c.lr_cosine:
            # warm up to the cosine-schedule value at lr_warm_epochs
            eta_min = c.lr * (c.lr_decay_rate ** 3)
            c.lr_warmup_to = eta_min + (c.lr - eta_min) * (
                    1 + math.cos(math.pi * c.lr_warm_epochs / c.meta_epochs)) / 2
        else:
            c.lr_warmup_to = c.lr
    ########
    os.environ['CUDA_VISIBLE_DEVICES'] = c.gpu
    c.use_cuda = not c.no_cuda and torch.cuda.is_available()
    context.set_context(mode=context.PYNATIVE_MODE, device_target=c.device_target)
    context.set_context(device_id=c.device_id)
    init_seeds(seed=int(time.time()))
    # selected function:
    if c.action_type in ['norm-train', 'norm-test']:
        testTrain(c)
    else:
        raise NotImplementedError('{} is not supported action-type!'.format(c.action_type))
-
-
if __name__ == '__main__':
    # Parse command-line arguments into a config namespace and run.
    c = get_args()
    main(c)
- '''
- class MyCell(nn.Cell):
-
- def __init__(self):
- super(MyCell, self).__init__(auto_prefix=False)
- self.conv2d=mindspore.nn.Conv2d(2, 4, kernel_size=3, stride=1)
- self.bn=mindspore.nn.BatchNorm2d(4)
- self.relu = ops.ReLU()
-
- def construct(self, x):
-
- x = self.conv2d(x)
- x = self.bn(x)
- return self.relu(x)
- x=mindspore.Tensor(np.ones([1, 2, 2, 3]), mindspore.float32)
- net=MyCell()
- for m in net.cells():
- if isinstance(m, mindspore.nn.Conv2d):
- # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
- m.weight = initializer(HeNormal(mode='fan_out', nonlinearity='relu'), m.weight.shape, m.weight.dtype)
- #print(m.trainable_params())
- elif isinstance(m, (mindspore.nn.BatchNorm2d, mindspore.nn.GroupNorm)):
- # nn.init.constant_(m.weight, 1)
- m.gamma = initializer(1, m.gamma.shape, m.gamma.dtype)
- # nn.init.constant_(m.bias, 0)
- m.beta = initializer(0, m.beta.shape, m.beta.dtype)
- #print(m.trainable_params())
- '''
- '''
- checkpoint_dir="C:/Users/10456/Desktop/javacode"
- encoderpre="Array"
- decoderpre="File"
- encodermodelname=os.path.join(checkpoint_dir,encoderpre)
decodermodelname=os.path.join(checkpoint_dir,decoderpre)
- encoderpath=""
- decoderspathlist=[]
- for file in os.listdir(checkpoint_dir):
- if file.startswith(encoderpre):
- encoderpath=file
- print(encoderpath)
- elif file.startswith(decoderpre):
- decoderspathlist.append(file)
- print(file)
- '''
|