|
- import argparse
- import os
-
- import numpy as np
- import torch
- import torch.nn as nn
- from PIL import Image
- from torch.autograd import Variable
- from torch.utils.data import DataLoader, Dataset
- from torchvision import transforms
- from torchvision.datasets import ImageFolder
- import Model.model as model
- from Model.context_model import Weighted_Gaussian
- import time
- from traindata import Traindataset
-
-
def codebook_range(model_index, train_loader, out_dir=None,
                   model_dir='/userdata/Weights.zip/Weights/'):
    """Record the per-channel integer range of encoder latents for one model.

    Loads the pretrained weights named ``models[model_index]`` from
    ``model_dir``: ``<name>.pkl`` for the image codec and ``<name>p.pkl`` for
    the context model. It then runs the codec's encoder over every batch
    yielded by ``train_loader`` and tracks the smallest and largest value seen
    in each of the ``M`` latent channels. The bounds are written to
    ``<out_dir>/<name>prior.txt``, one ``min<TAB>max`` line per channel.

    Args:
        model_index: Index into the 16-entry ``models`` list below (0-15).
            Indices 6, 7, 14 and 15 use the wider ``M=256, N2=192`` config;
            all others use ``M=192, N2=128``.
        train_loader: Iterable of image tensors fed to ``image_comp.encoder``.
            Presumably float tensors shaped (batch, 3, H, W); TODO confirm
            against ``Traindataset``.
        out_dir: Directory for the output file. When ``None`` (the default),
            falls back to the module-level ``args.out_dir`` set by the
            ``__main__`` block, preserving the original behavior.
        model_dir: Directory holding the ``.pkl`` checkpoints.

    Notes:
        The running bounds start at 0, so the recorded range always contains
        0. Each batch's extrema are widened with floor/ceil before being
        merged. This ensures non-integer latents are fully covered by the
        integer range (plain int assignment would truncate toward zero).
    """
    if out_dir is None:
        # Legacy behavior: the caller-supplied directory lives in global args.
        out_dir = args.out_dir

    M, N2 = 192, 128
    if model_index in (6, 7, 14, 15):
        M, N2 = 256, 192

    image_comp = model.Image_coding_multi_hyper(3, M, N2, M, M // 2)
    # NOTE(review): the context model is loaded but not used below; kept so a
    # missing/corrupt "<name>p.pkl" checkpoint still fails loudly as before.
    context = Weighted_Gaussian(M)

    models = ["mse200", "mse400", "mse800", "mse1600", "mse3200", "mse6400", "mse12800", "mse25600",
              "msssim4", "msssim8", "msssim16", "msssim32", "msssim64", "msssim128", "msssim320", "msssim640"]
    name = models[model_index]

    image_comp.load_state_dict(torch.load(os.path.join(model_dir, name + '.pkl')))
    context.load_state_dict(torch.load(os.path.join(model_dir, name + 'p.pkl')))

    image_comp.cuda()
    context.cuda()
    # Inference only: switch train-time layers (dropout / batch-norm, if any)
    # to their deterministic evaluation behavior.
    image_comp.eval()
    context.eval()

    print(name + ' model resumed')

    prior_min = np.zeros(M, dtype=np.int64)
    prior_max = np.zeros(M, dtype=np.int64)

    with torch.no_grad():
        for batch_x in train_loader:
            feature, _, _ = image_comp.encoder(batch_x.cuda())
            # feature is 4-D with channels on axis 1; reduce every other axis
            # in a single vectorized pass instead of one copy per channel.
            feat = feature.cpu().numpy()
            batch_min = np.floor(feat.min(axis=(0, 2, 3))).astype(np.int64)
            batch_max = np.ceil(feat.max(axis=(0, 2, 3))).astype(np.int64)
            np.minimum(prior_min, batch_min, out=prior_min)
            np.maximum(prior_max, batch_max, out=prior_max)

    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, name + 'prior.txt'), 'w') as fd:
        fd.writelines('{}\t{}\n'.format(lo, hi) for lo, hi in zip(prior_min, prior_max))

    print(name + ' finished')
-
-
-
-
if __name__ == '__main__':
    # Parse CLI options. `args` must remain a module-level name because
    # codebook_range() reads args.out_dir from it.
    parser = argparse.ArgumentParser()
    parser.add_argument('--out_dir', default='output0', type=str)
    args = parser.parse_args()
    print(args)

    dataset = Traindataset('/userdata')
    print(len(dataset))
    loader = DataLoader(dataset=dataset, batch_size=1, shuffle=True, num_workers=8)

    # Compute and save the latent-range file for each of the 16 checkpoints.
    for idx in range(16):
        codebook_range(idx, loader)
|