# import functools
import tensorflow as tf
# import torch
# import torch.nn as nn
# import torch.nn.functional as F


class Encoder(tf.keras.Model):
    def __init__(self, base_channel, num_layers,  # data_dim,
                 nonlinearity, norm_type='batch_norm',
                 max_channel=1024):
        super(Encoder, self).__init__()
        assert norm_type in ['instance_norm', 'batch_norm']

        # self.convs = nn.ModuleList()
        self.convs = []
        # self.convs.append(nn.Conv3d(data_dim, base_channel, 3, padding=1))
        self.convs.append(tf.keras.layers.Conv3D(base_channel, (3, 3, 3), padding='same',
                                                 data_format='channels_first', use_bias=True))

        for i in range(num_layers):
            # in_channels = min(base_channel * (2 ** i), max_channel)
            out_channels = min(base_channel * (2 ** (i + 1)), max_channel)
            if norm_type == 'instance_norm':
                # norm_layer = functools.partial(nn.InstanceNorm3d, affine=True)
                # (functools.partial fixes the affine argument so the later call passes one argument fewer.)
                # instance_norm is not implemented in this TF port; these placeholders are never used.
                norm_layer1 = 0
                norm_layer2 = 0
            elif norm_type == 'batch_norm':
                # norm_layer = functools.partial(nn.BatchNorm3d, affine=True)
                norm_layer1 = tf.keras.layers.BatchNormalization(axis=1)
                norm_layer2 = tf.keras.layers.BatchNormalization(axis=1)
                # Layers with trainable parameters must be created inside the loop, and two separate
                # instances are needed; otherwise the calls below would reuse the same layer object.

            # self.convs.append(nn.Sequential(
            #     nonlinearity,
            #     norm_layer(in_channels),
            #     nn.Conv3d(in_channels, out_channels, 3, stride=2, padding=1),
            #
            #     nonlinearity,
            #     norm_layer(out_channels),
            #     nn.Conv3d(out_channels, out_channels, 3, padding=1),
            # ))
            self.convs.append(tf.keras.Sequential([
                nonlinearity, norm_layer1,
                tf.keras.layers.Conv3D(out_channels, (3, 3, 3), strides=2, padding='same',
                                       data_format='channels_first', use_bias=True),
                nonlinearity, norm_layer2,
                tf.keras.layers.Conv3D(out_channels, (3, 3, 3), padding='same',
                                       data_format='channels_first', use_bias=True)]))

        self.out_channels = out_channels

    # def forward(self, x):
    def call(self, x):
        result = []
        for conv in self.convs:
            x = conv(x)
            result.append(x)
        return result[::-1]  # return the feature list in reverse order (deepest feature first)

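# Illustrative sketch (not from the original code): for an assumed input of shape
# (batch, 1, 64, 64, 64) with base_channel=8 and num_layers=3, the reversed feature list is
#   features[0]: (batch, 64,  8,  8,  8)  deepest feature, fed to MultiScaleDecoder.input_process
#   features[1]: (batch, 32, 16, 16, 16)  skip connection
#   features[2]: (batch, 16, 32, 32, 32)  skip connection
#   features[3]: (batch,  8, 64, 64, 64)  stem output (not consumed by the decoder below)
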
class MultiScaleDecoder(tf.keras.Model):
    def __init__(self, base_channel, num_layers, nonlinearity, norm_type='batch_norm', dropout_ratio=0.5):
        super(MultiScaleDecoder, self).__init__()
        assert norm_type in ['instance_norm', 'batch_norm']

        self.dropout_ratio = dropout_ratio

        # in_channels = base_channel * (2 ** num_layers)
        out_channels = base_channel * (2 ** (num_layers - 1))

        # self.input = nn.Sequential(
        #     nonlinearity,
        #     norm_layer(in_channels),
        #     nn.ConvTranspose3d(in_channels, out_channels, 3, stride=2, padding=1, output_padding=1),
        #
        #     nonlinearity,
        #     norm_layer(out_channels),
        #     nn.Conv3d(out_channels, out_channels, 3, padding=1),
        # )
        if norm_type == 'instance_norm':
            # norm_layer = functools.partial(nn.InstanceNorm3d, affine=True)
            # instance_norm is not implemented in this TF port; these placeholders are never used.
            norm_layer1 = 0
            norm_layer2 = 0
        elif norm_type == 'batch_norm':
            # norm_layer = functools.partial(nn.BatchNorm3d, affine=True)
            norm_layer1 = tf.keras.layers.BatchNormalization(axis=1)
            norm_layer2 = tf.keras.layers.BatchNormalization(axis=1)

        self.input_process = tf.keras.Sequential([
            nonlinearity, norm_layer1,
            tf.keras.layers.Conv3DTranspose(out_channels, 3, strides=2, padding='same',
                                            output_padding=None, data_format='channels_first'),
            nonlinearity, norm_layer2,
            tf.keras.layers.Conv3D(out_channels, (3, 3, 3), padding='same',
                                   data_format='channels_first', use_bias=True)])

        # self.convs = nn.ModuleList()
        self.convs = []
        # self.to_outputs = nn.ModuleList()
        self.to_outputs = []

        for i in reversed(range(num_layers - 1)):
            # in_channels = base_channel * (2 ** (i + 1)) * 2
            out_channels = base_channel * (2 ** i)

            if norm_type == 'instance_norm':
                # norm_layer = functools.partial(nn.InstanceNorm3d, affine=True)
                # instance_norm is not implemented in this TF port; these placeholders are never used.
                norm_layer1 = 0
                norm_layer2 = 0
            elif norm_type == 'batch_norm':
                # norm_layer = functools.partial(nn.BatchNorm3d, affine=True)
                norm_layer1 = tf.keras.layers.BatchNormalization(axis=1)
                norm_layer2 = tf.keras.layers.BatchNormalization(axis=1)

            # self.convs.append(nn.Sequential(
            #     nonlinearity,
            #     norm_layer(in_channels),
            #     nn.ConvTranspose3d(in_channels, out_channels, 3, stride=2, padding=1, output_padding=1),
            #
            #     nonlinearity,
            #     norm_layer(out_channels),
            #     nn.Conv3d(out_channels, out_channels, 3, padding=1),
            # ))
            self.convs.append(tf.keras.Sequential([
                nonlinearity, norm_layer1,
                tf.keras.layers.Conv3DTranspose(out_channels, 3, strides=2, padding='same',
                                                output_padding=None, data_format='channels_first'),
                nonlinearity, norm_layer2,
                tf.keras.layers.Conv3D(out_channels, (3, 3, 3), padding='same',
                                       data_format='channels_first', use_bias=True)]))

            to_output = ToOutput(nonlinearity)
            self.to_outputs.append(to_output)

        # self.sigmoid = nn.Sigmoid()
        self.sigmoid = tf.keras.activations.sigmoid
        # self.dropout = nn.Dropout(self.dropout_ratio)
        self.dropout = tf.keras.layers.Dropout(self.dropout_ratio)

    # def forward(self, features):
    def call(self, features, training_flag=True):
        x = self.input_process(features[0])
        skip, index = None, 0
        for conv, to_output in zip(self.convs, self.to_outputs):
            index += 1
            x = tf.keras.layers.concatenate([x, self.dropout(features[index], training=training_flag)], axis=1)
            x = conv(x)
            skip = to_output(x, skip)
        out = self.sigmoid(skip)
        return out

    def multi_scale_output(self, features, training_flag=True):
        x = self.input_process(features[0])
        skip, index = None, 0
        multi_scale_output = []
        for conv, to_output in zip(self.convs, self.to_outputs):
            index += 1
            x = tf.keras.layers.concatenate([x, self.dropout(features[index], training=training_flag)], axis=1)
            x = conv(x)
            skip = to_output(x, skip)
            multi_scale_output.append(self.sigmoid(skip))
        # The full-resolution prediction is the last element appended in the loop;
        # appending self.sigmoid(skip) again here would only duplicate it.
        return multi_scale_output

class ToOutput(tf.keras.Model):
    def __init__(self, nonlinearity):  # , input_channel
        super(ToOutput, self).__init__()
        self.nonlinearity = nonlinearity
        # self.conv = nn.Conv3d(input_channel, 1, 3, padding=1)
        self.conv = tf.keras.layers.Conv3D(1, (3, 3, 3), padding='same',
                                           data_format='channels_first', use_bias=True)
        self.upconv = tf.keras.layers.Conv3DTranspose(1, 3, strides=2, padding='same',
                                                      output_padding=None, data_format='channels_first')

    # def forward(self, x, skip=None):
    def call(self, x, skip=None):
        out = self.nonlinearity(x)
        out = self.conv(out)  # out.shape grows across decoder stages, e.g. (32, 1, 16, 16, 16) -> (32, 1, 32, 32, 32) -> (32, 1, 64, 64, 64)
        if skip is not None:
            # print("skip.shape:", skip.shape)  # e.g. (32, 1, 16, 16, 16), (32, 1, 32, 32, 32)
            # skip = F.interpolate(skip, scale_factor=2)
            # F.interpolate is replaced with a stride-2 transposed convolution in this TF port,
            # upsampling the previous coarse prediction to the current resolution.
            skip = self.upconv(skip)
            # print("skip.shape:", skip.shape)  # e.g. (32, 1, 32, 32, 32), (32, 1, 64, 64, 64)
            # print("out.shape:", out.shape)    # e.g. (32, 1, 32, 32, 32), (32, 1, 64, 64, 64)
            out = out + skip
        return out

class Generator(tf.keras.Model):
    def __init__(self, base_channel, num_layers):
        super(Generator, self).__init__()
        # data_dim = 1
        # nonlinearity = nn.LeakyReLU(0.1)
        nonlinearity = tf.keras.layers.LeakyReLU(alpha=0.1)
        self.encoder = Encoder(base_channel, num_layers, nonlinearity)  # data_dim,
        self.decoder = MultiScaleDecoder(base_channel, num_layers, nonlinearity)

    # def forward(self, x):
    def call(self, x, training_flag=True):
        out = self.decoder(self.encoder(x), training_flag=training_flag)
        return out

    def multi_scale_output(self, x, training_flag=True):
        features = self.encoder(x)
        multi_scale_output = self.decoder.multi_scale_output(features, training_flag=training_flag)
        return multi_scale_output
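

# Minimal usage sketch (not part of the original file). The hyperparameters and input shape
# below (base_channel=8, num_layers=3, a single-channel 64^3 volume) are illustrative
# assumptions only; note that channels_first 3D convolutions may require a GPU build of TensorFlow.
if __name__ == '__main__':
    generator = Generator(base_channel=8, num_layers=3)
    dummy = tf.random.normal((2, 1, 64, 64, 64))  # (batch, channels, depth, height, width)
    final = generator(dummy, training_flag=False)                      # full-resolution prediction
    scales = generator.multi_scale_output(dummy, training_flag=False)  # coarse-to-fine predictions
    print(final.shape, [s.shape for s in scales])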