|
- '''
- Contributors: Aolin Feng, Jianpin Lin, Dezhao Wang, Yueyu Hu, Tong Chen, Chuanmin Jia, Yihang Chen
- '''
-
- import numpy as np
- import torch
- import torch.nn as nn
- import torch.nn.functional as f
- from torch.distributions.uniform import Uniform
-
- from Model.basic_module import Non_local_Block, ResBlock, ScalingNet, P_NN
- from Model.context_model import P_Model, Weighted_Gaussian
- from Model.factorized_entropy_model import Entropy_bottleneck
- from Model.gaussian_entropy_model import Distribution_for_entropy
-
- # Multi-hyper Model
- from Model.hyper_module import h_analysisTransform, h_synthesisTransform
-
# Global feature switches read once from the project-wide config table.
# NOTE(review): `Util.config` exports a name `dict`, which shadows the
# builtin `dict` for the rest of this module — consider renaming upstream.
from Util.config import dict
USE_VR_MODEL = dict['USE_VR_MODEL']          # variable-rate (ScalingNet) model variant
USE_MULTI_HYPER = dict['USE_MULTI_HYPER']    # two-level hyperprior model variant
-
class lmd_map(nn.Module):
    """Map a scalar rate-distortion lambda to a per-channel gate vector.

    A 3-layer MLP. The last layer is initialised to weight=0, bias=4 so the
    initial output is sigmoid(4) ~= 1 for every channel, i.e. no modulation
    at the start of training.
    """

    def __init__(self, latent_channels):
        super(lmd_map, self).__init__()

        self.lc = int(latent_channels)
        self.fcn1 = nn.Linear(1, self.lc)
        self.fcn2 = nn.Linear(self.lc, self.lc)
        self.fcn3 = nn.Linear(self.lc, self.lc)
        # BUGFIX: nn.init.constant (no underscore) was removed from modern
        # PyTorch; use the in-place variant nn.init.constant_.
        nn.init.constant_(self.fcn3.weight, 0)
        nn.init.constant_(self.fcn3.bias, 4)  # so the initial sigmoid output is ~1

    def forward(self, y):
        """y: (batch, 1) lambda values -> (batch, lc) gates in (0, 1)."""
        y1 = self.fcn1(y)
        y1 = f.relu(y1)
        y2 = self.fcn2(y1)
        y2 = f.relu(y2)
        y3 = self.fcn3(y2)
        # f.sigmoid is deprecated; torch.sigmoid is the supported equivalent.
        return torch.sigmoid(y3)
-
class modnet(nn.Module):
    """Latent modulation network: produces a (0,1) mask over the latent,
    conditioned on the latent itself, a probability map `p`, and lambda.

    NOTE(review): `Image_coding_lmd.forward` calls this module with two
    arguments and unpacks three return values, whereas this definition takes
    (latent, p, lmd) and returns a single mask — the two appear to come from
    different revisions; confirm the intended variant before use.
    """

    def __init__(self, in_channels):
        super(modnet, self).__init__()

        self.m = int(in_channels)
        self.cnn = nn.Conv2d(self.m, self.m, 5, 1, 2)
        self.out = nn.Conv2d(2*self.m, self.m, 1, 1, 0)
        self.lmd_map = lmd_map(2*self.m)
        # BUGFIX: nn.init.constant was removed from modern PyTorch; use the
        # in-place nn.init.constant_. weight=0, bias=4 -> initial mask ~= 1.
        nn.init.constant_(self.out.weight, 0)
        nn.init.constant_(self.out.bias, 4)

    def forward(self, latent, p, lmd):
        """latent: (b, m, h, w); p: positive probability map (b, m, h, w);
        lmd: (b, 1) lambda values. Returns a (b, m, h, w) mask in (0, 1)."""
        p = torch.log(p)  # work with log-probabilities
        b = latent.size()[0]
        l = self.cnn(latent)
        lmd_info = self.lmd_map(lmd)
        # broadcastable per-channel gains derived from lambda
        lmd_info = torch.reshape(lmd_info, (b, 2*self.m, 1, 1))
        mixture = torch.cat((l, p), dim=1)
        mixture = mixture * lmd_info

        mask = self.out(mixture)
        # f.sigmoid is deprecated; torch.sigmoid is the supported equivalent.
        return torch.sigmoid(mask)
-
class Enc(nn.Module):
    """Analysis transform (encoder) with non-local attention blocks.

    Maps an image to the main latent plus hyper latent(s). With
    USE_MULTI_HYPER, two stacked hyper analysis transforms replace the
    convolutional hyper branch; with USE_VR_MODEL, per-stage ScalingNet gain
    units enable variable-rate coding.

    forward() returns:
        multi-hyper mode: (x6, x7, x8) — main latent, hyper-1, hyper-2.
        otherwise:        (x6, x10)    — main latent, hyper latent.
    """

    def __init__(self, num_features, N1, N2, M, M1):
        # typical configuration: num_features=3, N1=192, N2=128, M=192, M1=96
        super(Enc, self).__init__()
        self.N1 = int(N1)
        self.N2 = int(N2)
        self.M = int(M)
        self.M1 = int(M1)
        self.n_features = int(num_features)

        # Stage 1: full-resolution features, then 2x downsampling.
        self.conv1 = nn.Conv2d(self.n_features, self.M1, 5, 1, 2)
        self.trunk1 = nn.Sequential(ResBlock(self.M1, self.M1, 3, 1, 1),
                                    ResBlock(self.M1, self.M1, 3, 1, 1),
                                    nn.Conv2d(self.M1, 2*self.M1, 5, 2, 2))
        if USE_VR_MODEL:
            self.scaling1 = ScalingNet(2 * self.M1)

        self.down1 = nn.Conv2d(2*self.M1, self.M, 5, 2, 2)
        self.trunk2 = nn.Sequential(ResBlock(2*self.M1, 2*self.M1, 3, 1, 1),
                                    ResBlock(2*self.M1, 2*self.M1, 3, 1, 1),
                                    ResBlock(2*self.M1, 2*self.M1, 3, 1, 1))
        # NOTE(review): mask1 is constructed (and adds parameters) but is
        # never used in forward(); kept for checkpoint compatibility.
        self.mask1 = nn.Sequential(Non_local_Block(2*self.M1, self.M1),
                                   ResBlock(2*self.M1, 2*self.M1, 3, 1, 1),
                                   ResBlock(2*self.M1, 2*self.M1, 3, 1, 1),
                                   ResBlock(2*self.M1, 2*self.M1, 3, 1, 1),
                                   nn.Conv2d(2*self.M1, 2*self.M1, 1, 1, 0))
        if USE_VR_MODEL:
            self.scaling2 = ScalingNet(self.M)
        self.trunk3 = nn.Sequential(ResBlock(self.M, self.M, 3, 1, 1), ResBlock(self.M, self.M, 3, 1, 1),
                                    ResBlock(self.M, self.M, 3, 1, 1), nn.Conv2d(self.M, self.M, 5, 2, 2))
        if USE_VR_MODEL:
            self.scaling3 = ScalingNet(self.M)

        self.trunk4 = nn.Sequential(ResBlock(self.M, self.M, 3, 1, 1), ResBlock(self.M, self.M, 3, 1, 1),
                                    ResBlock(self.M, self.M, 3, 1, 1), nn.Conv2d(self.M, self.M, 5, 2, 2))
        if USE_VR_MODEL:
            self.scaling4 = ScalingNet(self.M)

        # Non-local attention branch gating the last main-latent stage.
        self.trunk5 = nn.Sequential(ResBlock(self.M, self.M, 3, 1, 1), ResBlock(self.M, self.M, 3, 1, 1),
                                    ResBlock(self.M, self.M, 3, 1, 1))
        self.mask2 = nn.Sequential(Non_local_Block(self.M, self.M // 2), ResBlock(self.M, self.M, 3, 1, 1),
                                   ResBlock(self.M, self.M, 3, 1, 1),
                                   ResBlock(self.M, self.M, 3, 1, 1), nn.Conv2d(self.M, self.M, 1, 1, 0))
        if USE_VR_MODEL:
            self.scaling5 = ScalingNet(self.M)

        # VR_MODEL --> single hyper model

        if USE_MULTI_HYPER:
            # Two stacked hyper analysis transforms (hyper-1 over the main
            # latent, hyper-2 over hyper-1); widths depend on bitrate range.
            if M == 192:  # low and middle bitrate range
                self.hyper_1_enc = h_analysisTransform(M, [384, 1536, 512, 256], [1, 1, 1])
                self.hyper_2_enc = h_analysisTransform(256, [64*4, 64*4*4, 32*4, 128], [1, 1, 1])
            elif M == 256:  # high bitrate range
                self.hyper_1_enc = h_analysisTransform(M, [384*2, 1536*2, 512*2, 256], [1, 1, 1])
                self.hyper_2_enc = h_analysisTransform(64*4, [64*4*2, 64*4*4*2, 32*4*2, 128], [1, 1, 1])
            else:
                # UNDEFINED MODEL HYPER PARAMETER
                raise Exception("UNDEFINED MODEL HYPER PARAMETER, M should be 192 or 256, but current M is "+str(M))
        else:
            # Convolutional hyper encoder (single-hyper path): two 2x
            # downsamplings, an attention-gated residual, then a 3x3 head.
            self.trunk6 = nn.Sequential(ResBlock(self.M, self.M, 3, 1, 1), ResBlock(self.M, self.M, 3, 1, 1),
                                        nn.Conv2d(self.M, self.M, 5, 2, 2))
            self.trunk7 = nn.Sequential(ResBlock(self.M, self.M, 3, 1, 1), ResBlock(self.M, self.M, 3, 1, 1),
                                        nn.Conv2d(self.M, self.M, 5, 2, 2))

            self.trunk8 = nn.Sequential(ResBlock(self.M, self.M, 3, 1, 1), ResBlock(self.M, self.M, 3, 1, 1),
                                        ResBlock(self.M, self.M, 3, 1, 1))
            self.mask3 = nn.Sequential(Non_local_Block(self.M, self.M // 2), ResBlock(self.M, self.M, 3, 1, 1),
                                       ResBlock(self.M, self.M, 3, 1, 1),
                                       ResBlock(self.M, self.M, 3, 1, 1), nn.Conv2d(self.M, self.M, 1, 1, 0))
            self.conv2 = nn.Conv2d(self.M, self.N2, 3, 1, 1)

    def forward(self, x, lambda_rd=None):
        """Encode image x; lambda_rd conditions the ScalingNet gains when
        USE_VR_MODEL is on (gains are skipped when lambda_rd is None)."""
        x1 = self.conv1(x)
        x2 = self.trunk1(x1)
        if USE_VR_MODEL and (lambda_rd is not None):
            x2 = self.scaling1(x2, lambda_rd)
        # NOTE(review): plain residual here — self.mask1 is never applied.
        x3 = self.trunk2(x2)+x2
        x3 = self.down1(x3)
        if USE_VR_MODEL and (lambda_rd is not None):
            x3 = self.scaling2(x3, lambda_rd)
        x4 = self.trunk3(x3)
        if USE_VR_MODEL and (lambda_rd is not None):
            x4 = self.scaling3(x4, lambda_rd)
        x5 = self.trunk4(x4)
        if USE_VR_MODEL and (lambda_rd is not None):
            x5 = self.scaling4(x5, lambda_rd)
        # attention-gated residual (f.sigmoid is deprecated -> torch.sigmoid)
        x6 = self.trunk5(x5)*torch.sigmoid(self.mask2(x5)) + x5
        if USE_VR_MODEL and (lambda_rd is not None):
            x6 = self.scaling5(x6, lambda_rd)

        if USE_MULTI_HYPER:
            x7 = self.hyper_1_enc(x6)  # hyper_1
            x8 = self.hyper_2_enc(x7)  # hyper_2
            return x6, x7, x8
        else:
            # hyper
            x7 = self.trunk6(x6)
            x8 = self.trunk7(x7)
            x9 = self.trunk8(x8)*torch.sigmoid(self.mask3(x8)) + x8
            x10 = self.conv2(x9)
            return x6, x10
-
-
class Hyper_Dec(nn.Module):
    """Hyper synthesis transform: upsamples the quantized hyper latent back
    to the main-latent resolution (two 2x transposed-conv stages) so the
    entropy-parameter network can consume it."""

    def __init__(self, N2, M):
        super(Hyper_Dec, self).__init__()

        self.N2 = N2
        self.M = M
        self.conv1 = nn.Conv2d(self.N2, M, 3, 1, 1)
        self.trunk1 = nn.Sequential(ResBlock(self.M, self.M, 3, 1, 1), ResBlock(self.M, self.M, 3, 1, 1),
                                    ResBlock(self.M, self.M, 3, 1, 1))
        self.mask1 = nn.Sequential(Non_local_Block(self.M, self.M // 2), ResBlock(self.M, self.M, 3, 1, 1),
                                   ResBlock(self.M, self.M, 3, 1, 1),
                                   ResBlock(self.M, self.M, 3, 1, 1), nn.Conv2d(self.M, self.M, 1, 1, 0))

        self.trunk2 = nn.Sequential(ResBlock(self.M, self.M, 3, 1, 1), ResBlock(self.M, self.M, 3, 1, 1),
                                    nn.ConvTranspose2d(M, M, 5, 2, 2, 1))
        self.trunk3 = nn.Sequential(ResBlock(self.M, self.M, 3, 1, 1), ResBlock(self.M, self.M, 3, 1, 1),
                                    nn.ConvTranspose2d(M, M, 5, 2, 2, 1))

    def forward(self, xq2):
        """xq2: quantized hyper latent -> upsampled feature map for P_Model."""
        x1 = self.conv1(xq2)
        # attention-gated residual (f.sigmoid is deprecated -> torch.sigmoid)
        x2 = self.trunk1(x1) * torch.sigmoid(self.mask1(x1)) + x1
        x3 = self.trunk2(x2)
        x4 = self.trunk3(x3)

        return x4
-
-
class Dec(nn.Module):
    """Synthesis transform (decoder): maps the quantized main latent back to
    an image. Mirror of Enc, with optional ScalingNet gain units for
    variable-rate decoding when USE_VR_MODEL is on."""

    def __init__(self, input_features, N1, M, M1):
        super(Dec, self).__init__()

        self.N1 = N1
        self.M = M
        self.M1 = M1
        self.input = input_features

        if USE_VR_MODEL:
            self.scaling1 = ScalingNet(self.M)
        self.trunk1 = nn.Sequential(ResBlock(self.M, self.M, 3, 1, 1), ResBlock(self.M, self.M, 3, 1, 1),
                                    ResBlock(self.M, self.M, 3, 1, 1))
        self.mask1 = nn.Sequential(Non_local_Block(self.M, self.M // 2), ResBlock(self.M, self.M, 3, 1, 1),
                                   ResBlock(self.M, self.M, 3, 1, 1),
                                   ResBlock(self.M, self.M, 3, 1, 1), nn.Conv2d(self.M, self.M, 1, 1, 0))

        self.up1 = nn.ConvTranspose2d(M, M, 5, 2, 2, 1)
        if USE_VR_MODEL:
            self.scaling2 = ScalingNet(self.M)
        self.trunk2 = nn.Sequential(ResBlock(self.M, self.M, 3, 1, 1), ResBlock(self.M, self.M, 3, 1, 1),
                                    ResBlock(self.M, self.M, 3, 1, 1), nn.ConvTranspose2d(M, M, 5, 2, 2, 1))
        if USE_VR_MODEL:
            self.scaling3 = ScalingNet(self.M)
        self.trunk3 = nn.Sequential(ResBlock(self.M, self.M, 3, 1, 1), ResBlock(self.M, self.M, 3, 1, 1),
                                    ResBlock(self.M, self.M, 3, 1, 1), nn.ConvTranspose2d(M, 2*self.M1, 5, 2, 2, 1))
        if USE_VR_MODEL:
            self.scaling4 = ScalingNet(2 * self.M1)
        self.trunk4 = nn.Sequential(ResBlock(2*self.M1, 2*self.M1, 3, 1, 1), ResBlock(2*self.M1, 2*self.M1, 3, 1, 1),
                                    ResBlock(2*self.M1, 2*self.M1, 3, 1, 1))
        # NOTE(review): mask2 is constructed (and adds parameters) but is
        # never used in forward(); kept for checkpoint compatibility.
        self.mask2 = nn.Sequential(Non_local_Block(2*self.M1, self.M1), ResBlock(2*self.M1, 2*self.M1, 3, 1, 1),
                                   ResBlock(2*self.M1, 2*self.M1, 3, 1, 1),
                                   ResBlock(2*self.M1, 2*self.M1, 3, 1, 1), nn.Conv2d(2*self.M1, 2*self.M1, 1, 1, 0))
        if USE_VR_MODEL:
            self.scaling5 = ScalingNet(2 * self.M1)
        self.trunk5 = nn.Sequential(nn.ConvTranspose2d(2*M1, M1, 5, 2, 2, 1), ResBlock(self.M1, self.M1, 3, 1, 1),
                                    ResBlock(self.M1, self.M1, 3, 1, 1), ResBlock(self.M1, self.M1, 3, 1, 1))

        self.conv1 = nn.Conv2d(self.M1, self.input, 5, 1, 2)

    def forward(self, x, lambda_rd=None):
        """Decode the quantized latent x; lambda_rd conditions the ScalingNet
        gains when USE_VR_MODEL is on (skipped when lambda_rd is None)."""
        if USE_VR_MODEL and (lambda_rd is not None):
            x = self.scaling1(x, lambda_rd)
        # attention-gated residual (f.sigmoid is deprecated -> torch.sigmoid)
        x1 = self.trunk1(x)*torch.sigmoid(self.mask1(x))+x
        x1 = self.up1(x1)
        if USE_VR_MODEL and (lambda_rd is not None):
            x1 = self.scaling2(x1, lambda_rd)
        x2 = self.trunk2(x1)
        if USE_VR_MODEL and (lambda_rd is not None):
            x2 = self.scaling3(x2, lambda_rd)
        x3 = self.trunk3(x2)
        if USE_VR_MODEL and (lambda_rd is not None):
            x3 = self.scaling4(x3, lambda_rd)
        x4 = self.trunk4(x3)+x3
        if USE_VR_MODEL and (lambda_rd is not None):
            x4 = self.scaling5(x4, lambda_rd)
        x5 = self.trunk5(x4)
        output = self.conv1(x5)
        return output
-
-
class Image_coding(nn.Module):
    """End-to-end image codec with a single hyperprior: Enc -> factorized
    bottleneck on the hyper latent -> Hyper_Dec + P_Model for the Gaussian
    entropy model of the main latent -> Dec."""

    def __init__(self, input_features, N1, N2, M, M1):
        # typical configuration: input_features=3, N1=192, N2=128, M=192, M1=96
        super(Image_coding, self).__init__()
        self.N1 = N1
        self.encoder = Enc(input_features, N1, N2, M, M1)
        self.factorized_entropy_func = Entropy_bottleneck(N2)
        self.hyper_dec = Hyper_Dec(N2, M)
        self.p = P_Model(M)
        self.gaussin_entropy_func = Distribution_for_entropy()
        self.decoder = Dec(input_features, N1, M, M1)

    def add_noise(self, x):
        """Additive U(-0.5, 0.5) noise — the differentiable proxy for rounding."""
        noise = np.random.uniform(-0.5, 0.5, x.size())
        # Build the noise on x's device instead of hard-coding .cuda(), so
        # the model also runs on CPU; identical behavior for CUDA inputs.
        noise = torch.tensor(noise, dtype=x.dtype, device=x.device)
        return x + noise

    def forward(self, x, if_training, lambda_rd=None):
        """if_training: 0 -> additive noise, 1 -> universal quantization,
        anything else -> hard rounding. Returns
        [reconstruction, p(latent), p(hyper), quantized latent, entropy params]."""
        x1, x2 = self.encoder(x, lambda_rd)
        xq2, xp2 = self.factorized_entropy_func(x2, if_training)
        x3 = self.hyper_dec(xq2)
        hyper_dec = self.p(x3)
        if if_training == 0:
            xq1 = self.add_noise(x1)
        elif if_training == 1:
            xq1 = UniverseQuant.apply(x1)
        else:
            xq1 = torch.round(x1)
        xp1 = self.gaussin_entropy_func(xq1, hyper_dec)

        output = self.decoder(xq1, lambda_rd)

        return [output, xp1, xp2, xq1, hyper_dec]
-
class Image_coding_lmd(nn.Module):
    """Variable-lambda codec: draws a random rate-distortion lambda per batch
    element and modulates the latents via modnet, returning the lambda-weighted
    distortion loss and the total rate in bits.

    NOTE(review): the modnet calls in forward() pass 2 arguments and unpack
    3 return values, while modnet.forward takes (latent, p, lmd) and returns
    one mask — the two definitions appear to be from different revisions;
    confirm the intended modnet variant before running this path.
    """

    def __init__(self, input_features, N1, N2, M, M1):
        # typical configuration: input_features=3, N1=192, N2=128, M=192, M1=96
        super(Image_coding_lmd, self).__init__()
        self.N1 = N1
        self.encoder = Enc(input_features, N1, N2, M, M1)
        self.factorized_entropy_func = Entropy_bottleneck(N2)
        self.hyper_dec = Hyper_Dec(N2, M)
        self.p = P_Model(M)
        self.decoder = Dec(input_features, N1, M, M1)
        self.modnet_main = modnet(256)
        self.modnet_hyper = modnet(192)
        self.context = Weighted_Gaussian(M)

    def add_noise(self, x):
        """Additive U(-0.5, 0.5) noise — the differentiable proxy for rounding."""
        noise = np.random.uniform(-0.5, 0.5, x.size())
        # Build the noise on x's device instead of hard-coding .cuda(), so
        # the model also runs on CPU; identical behavior for CUDA inputs.
        noise = torch.tensor(noise, dtype=x.dtype, device=x.device)
        return x + noise

    def forward(self, x, if_training, lambda_rd=None):
        b = x.size()[0]
        # 45 candidate lambdas: 1..8 step 1, 12..64 step 4, 80..256 step 8.
        lmd_set = [i + 1 for i in range(8)] + [4 * i + 12 for i in range(14)] + [8 * i + 80 for i in range(23)]
        # BUGFIX: np.int was removed from NumPy; use an explicit integer
        # dtype, and derive the range from the table (len(lmd_set) == 45).
        random_index = np.array(np.random.rand(b) * len(lmd_set), dtype=np.int64)

        random_lambda = [lmd_set[idx] for idx in random_index]
        random_lambda = np.array(random_lambda, dtype=np.float32)

        # Keep lambda on x's device (was hard-coded .cuda()).
        lmd = torch.from_numpy(random_lambda).to(x.device)
        lmd = torch.reshape(lmd, (b, 1))

        x1, x2 = self.encoder(x, lambda_rd)
        # Because a context model follows, the mask must be applied before the
        # context step, otherwise encoding and decoding would mismatch.
        x1, rou, rou0 = self.modnet_main(x1, lmd)
        # Mask the hyper latent first, then compute its probability from the
        # masked tensor: the factorized model needs no change and the hyper
        # bitrate can vary with lambda.
        x2, _, _ = self.modnet_hyper(x2, lmd)
        xq2, xp2 = self.factorized_entropy_func(x2, if_training)

        x3 = self.hyper_dec(xq2)
        hyper_dec = self.p(x3)
        if if_training == 0:
            xq1 = self.add_noise(x1)
        elif if_training == 1:
            xq1 = UniverseQuant.apply(x1)
        else:
            xq1 = torch.round(x1)

        xp1, _ = self.context(xq1, hyper_dec)
        output = self.decoder(xq1, lambda_rd)

        lmd = torch.squeeze(lmd, dim=1)
        delta = (output - x) ** 2
        delta = delta.view(b, -1)
        batch_mse = torch.mean(delta, dim=1, keepdim=False)
        # lambda-weighted distortion, averaged over the batch
        dloss = torch.mean(lmd * batch_mse)

        # total rate in bits (change of base from nats)
        bits = torch.sum(torch.log(xp1)) / (-np.log(2)) + torch.sum(torch.log(xp2)) / (-np.log(2))

        return dloss, bits, rou, rou0
-
class Image_coding_multi_hyper(nn.Module):
    """Codec with two stacked hyperpriors.

    stage == 1 trains the first hyper level only: x2 is modeled directly by a
    factorized bottleneck and the second hyper level is bypassed. Any other
    stage uses the full hierarchy (factorized bottleneck on x3, Gaussian
    model for x2 conditioned on the decoded hyper-2).
    """

    def __init__(self, input_features, N1, N2, M, M1):
        # typical configuration: input_features=3, N1=192, N2=128, M=192, M1=96
        super(Image_coding_multi_hyper, self).__init__()
        self.N1 = N1
        self.encoder = Enc(input_features, N1, N2, M, M1)
        self.factorized_entropy_func = Entropy_bottleneck(128)

        # Synthesis transforms for the two hyper levels; widths depend on the
        # bitrate range (mirrors the encoder's choice in Enc.__init__).
        if M == 192:
            self.hyper_1_dec = h_synthesisTransform(256, [768, 768, 768, M], [1, 1, 1])
            self.hyper_2_dec = h_synthesisTransform(128, [64*4, 64*4, 64*4, 64*4], [1, 1, 1])
        elif M == 256:
            self.hyper_1_dec = h_synthesisTransform(256, [768*2, 768*2, 768*2, M], [1, 1, 1])
            self.hyper_2_dec = h_synthesisTransform(128, [64*4*2, 64*4*2, 64*4*2, 64*4], [1, 1, 1])

        self.p = P_Model(M)
        self.p_2 = P_Model(256)
        self.factorized_entropy_func_for_hyper = Entropy_bottleneck(256)  # STAGE 1 ONLY
        self.gaussin_entropy_func_for_hyper = Distribution_for_entropy()
        self.gaussin_entropy_func = Distribution_for_entropy()
        self.decoder = Dec(input_features, N1, M, M1)

    def add_noise(self, x):
        """Additive U(-0.5, 0.5) noise — the differentiable proxy for rounding."""
        noise = np.random.uniform(-0.5, 0.5, x.size())
        # Build the noise on x's device instead of hard-coding .cuda(), so
        # the model also runs on CPU; identical behavior for CUDA inputs.
        noise = torch.tensor(noise, dtype=x.dtype, device=x.device)
        return x + noise

    def forward(self, x, if_training, stage=-1):
        """if_training: 0 -> additive noise, 1 -> universal quantization,
        anything else -> hard rounding."""
        x1, x2, x3 = self.encoder(x)
        if stage == 1:
            # Stage 1: factorized model directly on x2; hyper-2 bypassed.
            xq2, xp2 = self.factorized_entropy_func_for_hyper(x2, if_training)
        else:
            xq3, xp3 = self.factorized_entropy_func(x3, if_training)

            x4 = self.hyper_2_dec(xq3)
            hyper_2_dec = self.p_2(x4)

            # BUGFIX: this quantization/probability chain previously ran
            # unconditionally, which raised NameError on xq3 when stage == 1
            # and clobbered the stage-1 xq2/xp2; it belongs to the full
            # (non-stage-1) path only.
            if if_training == 0:
                xq2 = self.add_noise(x2)
            elif if_training == 1:
                xq2 = UniverseQuant.apply(x2)
            else:
                xq2 = torch.round(x2)

            xp2 = self.gaussin_entropy_func_for_hyper(xq2, hyper_2_dec)

        x5 = self.hyper_1_dec(xq2)
        hyper_dec = self.p(x5)

        if if_training == 0:
            xq1 = self.add_noise(x1)
        elif if_training == 1:
            xq1 = UniverseQuant.apply(x1)
        else:
            xq1 = torch.round(x1)
        xp1 = self.gaussin_entropy_func(xq1, hyper_dec)

        output = self.decoder(xq1)

        if stage == 1:
            return output, xp1, xp2, xq1, hyper_dec
        else:
            return output, xp1, xp2, xq1, hyper_dec, xp3, xq3
-
-
class Image_coding_multi_hyper_res(nn.Module):
    """Two-level hyperprior codec with residual prediction: each level
    predicts the next latent (via P_NN) and only the prediction residual is
    quantized and entropy-coded."""

    def __init__(self, input_features, N1, N2, M, M1):
        # typical configuration: input_features=3, N1=192, N2=128, M=192, M1=96
        super(Image_coding_multi_hyper_res, self).__init__()
        self.N1 = N1
        self.encoder = Enc(input_features, N1, N2, M, M1)
        self.factorized_entropy_func = Entropy_bottleneck(128)

        # Synthesis transforms for the two hyper levels; widths depend on the
        # bitrate range (mirrors the encoder's choice in Enc.__init__).
        if M == 192:
            self.hyper_1_dec = h_synthesisTransform(256, [768, 768, 768, M], [1, 1, 1])
            self.hyper_2_dec = h_synthesisTransform(128, [64 * 4, 64 * 4, 64 * 4, 64 * 4], [1, 1, 1])
        elif M == 256:
            self.hyper_1_dec = h_synthesisTransform(256, [768 * 2, 768 * 2, 768 * 2, M], [1, 1, 1])
            self.hyper_2_dec = h_synthesisTransform(128, [64 * 4 * 2, 64 * 4 * 2, 64 * 4 * 2, 64 * 4], [1, 1, 1])

        self.p = P_Model(M)
        self.p_2 = P_Model(256)
        self.Y = P_NN(M, 2)          # predicts x1 from the decoded hyper-1
        self.Y_2 = P_NN(256, 2)      # predicts x2 from the decoded hyper-2
        self.factorized_entropy_func_for_hyper = Entropy_bottleneck(256)  # STAGE 1 ONLY
        self.gaussin_entropy_func_for_hyper = Distribution_for_entropy()
        self.gaussin_entropy_func = Distribution_for_entropy()
        self.decoder = Dec(input_features, N1, M, M1)

    def add_noise(self, x):
        """Additive U(-0.5, 0.5) noise — the differentiable proxy for rounding."""
        noise = np.random.uniform(-0.5, 0.5, x.size())
        # Build the noise on x's device instead of hard-coding .cuda(), so
        # the model also runs on CPU; identical behavior for CUDA inputs.
        noise = torch.tensor(noise, dtype=x.dtype, device=x.device)
        return x + noise

    def forward(self, x, if_training, stage=-1):
        """if_training: 0 -> additive noise, 1 -> universal quantization,
        anything else -> hard rounding. `stage` is currently unused here."""
        x1, x2, x3 = self.encoder(x)

        xq3, xp3 = self.factorized_entropy_func(x3, if_training)

        x4 = self.hyper_2_dec(xq3)

        hyper_2_dec = self.p_2(x4)

        # Level 2: code only the residual against the prediction from x4.
        x2_predict = self.Y_2(x4)
        x2_res = x2 - x2_predict

        if if_training == 0:
            xq2_res = self.add_noise(x2_res)
        elif if_training == 1:
            xq2_res = UniverseQuant.apply(x2_res)
        else:
            xq2_res = torch.round(x2_res)

        xp2_res = self.gaussin_entropy_func_for_hyper(xq2_res, hyper_2_dec)

        xq2 = xq2_res + x2_predict

        x5 = self.hyper_1_dec(xq2)
        hyper_dec = self.p(x5)

        # Level 1: same residual scheme for the main latent.
        x1_predict = self.Y(x5)
        x1_res = x1 - x1_predict

        if if_training == 0:
            xq1_res = self.add_noise(x1_res)
        elif if_training == 1:
            xq1_res = UniverseQuant.apply(x1_res)
        else:
            xq1_res = torch.round(x1_res)
        xp1_res = self.gaussin_entropy_func(xq1_res, hyper_dec)

        xq1 = xq1_res + x1_predict

        output = self.decoder(xq1)

        return output, xp1_res, xp2_res, xq1_res, xq1, hyper_dec, xp3, xq3
-
-
class UniverseQuant(torch.autograd.Function):
    """Universal quantization with a straight-through gradient.

    Forward: round(x + u) - u with u ~ U(-0.5 * 2^b, 0.5 * 2^b), where the
    exponent b ~ U(-1, 1) is drawn once per call (so the dither amplitude is
    itself randomized). Backward: identity (gradient passes through).
    """

    @staticmethod
    def forward(ctx, x):
        b = np.random.uniform(-1, 1)
        half = 0.5 * (2 ** b)
        # BUGFIX: sample on CPU then move to x's device instead of
        # hard-coding .cuda(), so the op also works without a GPU;
        # unchanged behavior for CUDA inputs.
        uniform_distribution = Uniform(-half * torch.ones(x.size()),
                                       half * torch.ones(x.size())).sample().to(x.device)
        return torch.round(x + uniform_distribution) - uniform_distribution

    @staticmethod
    def backward(ctx, g):
        # straight-through estimator: pass the gradient unchanged
        return g
|