|
# -*- coding: utf-8 -*-
"""
@Author  : zhwzhong
@License : (C) Copyright 2013-2018, hit
@Contact : zhwzhong@hit.edu.cn (alt: zhwzhong.hit@gmail.com)
@Software: PyCharm
@File    : chop_forward.py
@Time    : 2021/10/25 21:12 (originally 2019/12/15 11:21)
@Desc    : Patch-based ("chopped") forward pass helpers for running a
           network over large images in overlapping tiles.
"""
- import copy
- import torch
- import torch.utils.data as utils
- from torch.autograd import Variable
-
-
class ImageSplitter:
    """Split an image tensor into overlapping patches and merge processed
    patches back into a full image.

    Border padding plus overlapping tiles avoid instability of edge values.
    Thanks to waifu2x's author nagadomi for the suggestion
    (https://github.com/nagadomi/waifu2x/issues/238).
    """

    def __init__(self, patch_size, scale_factor, stride, num_channels):
        """
        Args:
            patch_size: side length of each square patch (in input pixels).
            scale_factor: upscaling factor applied by the network between
                split and merge.
            stride: step between neighbouring patches (in input pixels);
                should be <= patch_size so the tiles cover the image.
            num_channels: channel count of the merged output tensor.
        """
        # width/height are recorded by split_img_tensor and reused by
        # merge_img_tensor, so split must be called before merge.
        self.width = 0
        self.height = 0
        self.stride = stride
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.scale_factor = scale_factor

    def split_img_tensor(self, img_tensor):
        """Cut ``img_tensor`` (B, C, H, W) into overlapping patches.

        If the image is smaller than ``patch_size``, it is zero-padded
        (centred) up to the patch size first.

        Returns:
            Tensor of shape (num_patches * B, C, patch_size, patch_size),
            patches ordered row-major over the tiling grid.
        """
        batch, channel, height, width = img_tensor.size()
        self.height = height
        self.width = width

        # delta > 0 only when the image is smaller than one patch.
        side = min(height, width, self.patch_size)
        delta = self.patch_size - side
        # new_zeros keeps the dtype/device of the input instead of
        # silently allocating a float32 CPU buffer (fixes device/dtype
        # mismatch for CUDA or non-float32 inputs).
        Z = img_tensor.new_zeros(batch, channel, height + delta, width + delta)
        Z[:, :, delta // 2:height + delta // 2, delta // 2:width + delta // 2] = img_tensor
        batch, channel, new_height, new_width = Z.size()

        patch_box = []
        # Overlapping tiling: windows are clamped at the bottom/right
        # border so every patch is exactly patch_size x patch_size.
        for i in range(0, new_height, self.stride):
            for j in range(0, new_width, self.stride):
                x = min(new_height, i + self.patch_size)
                y = min(new_width, j + self.patch_size)
                patch_box.append(Z[:, :, x - self.patch_size:x, y - self.patch_size:y])

        return torch.cat(patch_box, dim=0)

    def merge_img_tensor(self, list_img_tensor):
        """Reassemble patches produced by :meth:`split_img_tensor`.

        Args:
            list_img_tensor: patches in the same order split produced them,
                each of spatial size patch_size * scale_factor.

        Returns:
            CPU tensor of shape (1, num_channels, H * scale, W * scale).
        """
        # Shallow copy so pop(0) below does not mutate the caller's list.
        img_tensors = copy.copy(list_img_tensor)

        patch_size = self.patch_size * self.scale_factor
        stride = self.stride * self.scale_factor
        height = self.height * self.scale_factor
        width = self.width * self.scale_factor
        side = min(height, width, patch_size)
        delta = patch_size - side
        new_height = delta + height
        new_width = delta + width
        out = torch.zeros((1, self.num_channels, new_height, new_width))
        mask = torch.zeros((1, self.num_channels, new_height, new_width))

        # Mirror the traversal order of split_img_tensor; overlapping
        # regions are summed and averaged via the mask counts.
        for i in range(0, new_height, stride):
            for j in range(0, new_width, stride):
                x = min(new_height, i + patch_size)
                y = min(new_width, j + patch_size)
                mask_patch = torch.zeros((1, self.num_channels, new_height, new_width))
                out_patch = torch.zeros((1, self.num_channels, new_height, new_width))
                mask_patch[:, :, (x - patch_size):x, (y - patch_size):y] = 1.0
                out_patch[:, :, (x - patch_size):x, (y - patch_size):y] = img_tensors.pop(0)
                mask = mask + mask_patch
                out = out + out_patch

        # Average the overlapping contributions.
        out = out / mask

        # Crop away the centring pad added in split. The previous crop
        # [delta//2 : new - delta//2] returned one extra row/column when
        # delta was odd; take exactly height/width from the same offset
        # split placed the image at.
        top = delta // 2
        left = delta // 2
        return out[:, :, top:top + height, left:left + width]
-
-
def chop_forward(img, network, patch_size, scale_factor, stride, num_channels, device):
    """Run ``network`` over ``img`` patch-by-patch and merge the results.

    Splits the image into overlapping patches, feeds them through the
    network in mini-batches of 64, and merges the outputs back together.
    Note: the input batch size must be 1 (the merge step produces a
    single image).

    Args:
        img: input tensor of shape (1, C, H, W).
        network: callable mapping a patch batch to an output batch whose
            spatial size is the input's times ``scale_factor``.
        patch_size: side length of each square patch.
        scale_factor: upscaling factor applied by ``network``.
        stride: step between neighbouring patches; should be <= patch_size.
        num_channels: channel count of the network output.
        device: device used for inference and for the returned tensor.

    Returns:
        Tensor of shape (1, num_channels, H * scale_factor,
        W * scale_factor) on ``device``.
    """
    img_splitter = ImageSplitter(patch_size=patch_size, scale_factor=scale_factor,
                                 stride=stride, num_channels=num_channels)
    img_patch = img_splitter.split_img_tensor(img)

    testset = utils.TensorDataset(img_patch)
    test_dataloader = utils.DataLoader(testset, num_workers=0, drop_last=False,
                                       batch_size=64, shuffle=False)

    out_box = []
    # Single no_grad context around the whole loop: inference only,
    # no autograd graph needed.
    with torch.no_grad():
        for batch in test_dataloader:
            # torch.autograd.Variable is deprecated since PyTorch 0.4;
            # tensors are used directly.
            inputs = batch[0].to(device)
            prediction = network(inputs)
            # Collect per-patch results in split order for merging.
            for j in range(prediction.shape[0]):
                out_box.append(prediction[j])

    return img_splitter.merge_img_tensor(out_box).to(device)
-
- # device_id = torch.cuda.current_device()
- # net = torch.nn.Conv2d(in_channels=6, out_channels=6, kernel_size=1).to(device_id)
- # arr =torch.randn(2, 6, 512, 512).to(device_id)
- # out = chop_forward(arr, net, 32, 1, 16, device=device_id, num_channels=6)
- # print(out.size())
|