#50 modify test

Merged
Manson merged 1 commit from Manson/MSAdapter:add_bn_test into master 1 year ago
  1. testing/layers/test_batchnorm.py  (+60, −57)
  2. testing/layers/test_linear.py  (+32, −40)

testing/layers/test_batchnorm.py  (+60, −57)

@@ -1,68 +1,71 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import numpy as np
import torch

from mindspore import context
import mindspore as ms

from ms_adapter.pytorch.nn import Module
from ms_adapter.pytorch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d
from ms_adapter.pytorch import tensor

context.set_context(mode=ms.GRAPH_MODE)


class BnModel1d(Module):
    def __init__(self):
        super(BnModel1d, self).__init__()
        self.bn1 = BatchNorm1d(num_features=32)
        self.bn2 = BatchNorm1d(32, affine=False)

    def forward(self, inputs):
        x = self.bn1(inputs)
        x = self.bn2(x)
        return x

class BnModel2d(Module):
    def __init__(self):
        super(BnModel2d, self).__init__()
        self.bn1 = BatchNorm2d(num_features=32)
        self.bn2 = BatchNorm2d(32, affine=False)

    def forward(self, inputs):
        x = self.bn1(inputs)
        x = self.bn2(x)
        return x

class BnModel3d(Module):
    def __init__(self):
        super(BnModel3d, self).__init__()
        self.bn1 = BatchNorm3d(num_features=32)
        self.bn2 = BatchNorm3d(32, affine=False)

    def forward(self, inputs):
        x = self.bn1(inputs)
        x = self.bn2(x)
        return x

model1d = BnModel1d()
model2d = BnModel2d()
model3d = BnModel3d()
model1d.train()
model2d.train()
model3d.train()

for n, v in model2d.named_parameters():
    print(n, v.shape)

inputs1d = tensor(np.ones(shape=(5, 32)), ms.float32)
output1d = model1d(inputs1d)
print(output1d.shape)

inputs2d = tensor(np.ones(shape=(5, 32, 5, 5)), ms.float32)
output2d = model2d(inputs2d)
print(output2d.shape)

inputs3d = tensor(np.ones(shape=(5, 32, 5, 5, 5)), ms.float32)
output3d = model3d(inputs3d)
print(output3d.shape)
def test_bn():
    class BnModel1d(Module):
        def __init__(self):
            super(BnModel1d, self).__init__()
            self.bn1 = BatchNorm1d(num_features=32)
            self.bn2 = BatchNorm1d(32, affine=False)

        def forward(self, inputs):
            x = self.bn1(inputs)
            x = self.bn2(x)
            return x

    class BnModel2d(Module):
        def __init__(self):
            super(BnModel2d, self).__init__()
            self.bn1 = BatchNorm2d(num_features=32)
            self.bn2 = BatchNorm2d(32, affine=False)

        def forward(self, inputs):
            x = self.bn1(inputs)
            x = self.bn2(x)
            return x

    class BnModel3d(Module):
        def __init__(self):
            super(BnModel3d, self).__init__()
            self.bn1 = BatchNorm3d(num_features=32)
            self.bn2 = BatchNorm3d(32, affine=False)

        def forward(self, inputs):
            x = self.bn1(inputs)
            x = self.bn2(x)
            return x

    model1d = BnModel1d()
    model2d = BnModel2d()
    model3d = BnModel3d()
    model1d.train()
    model2d.train()
    model3d.train()

    inputs1d = tensor(np.ones(shape=(5, 32)), ms.float32)
    output1d = model1d(inputs1d)
    assert output1d.shape == (5, 32)

    inputs2d = tensor(np.ones(shape=(5, 32, 5, 5)), ms.float32)
    output2d = model2d(inputs2d)
    assert output2d.shape == (5, 32, 5, 5)

    inputs3d = tensor(np.ones(shape=(5, 32, 5, 5, 5)), ms.float32)
    output3d = model3d(inputs3d)
    assert output3d.shape == (5, 32, 5, 5, 5)


test_bn()
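
Note: the asserts above only check output shapes. For a constant input, batch norm in training mode sees zero per-channel variance, so with the default gamma=1 and beta=0 the normalized output should be all zeros; a value-level check could be added along the lines below. This is a sketch against stock torch (already imported at the top of the file) rather than the adapter classes, so it assumes nothing about the adapter's tensor-to-numpy conversion API.

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(num_features=32)
bn.train()
x = torch.ones(5, 32, 5, 5)
y = bn(x)
# A constant input has zero per-channel variance, so after normalization
# (and with the default gamma=1, beta=0) the output is ~0 up to eps effects.
assert torch.allclose(y, torch.zeros_like(y), atol=1e-3)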

testing/layers/test_linear.py  (+32, −40)

@@ -11,43 +11,35 @@ import mindspore as ms
context.set_context(mode=ms.GRAPH_MODE)


class LinearModel(Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        self.line1 = Linear(in_features=32, out_features=64)
        self.line2 = Linear(in_features=64, out_features=128, bias=False)
        self.line3 = Linear(in_features=128, out_features=10)

    def forward(self, inputs):
        x = self.line1(inputs)
        x = self.line2(x)
        x = self.line3(x)
        return x

model = LinearModel()
model.train()

for n, v in model.named_parameters():
    print(n, v.shape)

from ms_adapter.pytorch.nn import init

# #define the initial function to init the layer's parameters for the network
# def weight_init(m):
#     for m in m.modules():
#         if isinstance(m, Linear):
#             m.weight.data.normal_(0,0.01)
#             if m.has_bias:
#                 m.bias.data.zero_()
# weight_init(model)

def weight_init(m):
    if isinstance(m, Linear):
        m.weight.data.normal_(0,0.01)
        if m.has_bias:
            m.bias.data.zero_()
model.apply(weight_init)

inputs = tensor(np.ones(shape=(5, 32)), ms.float32)
output = model(inputs)
print(output)
def test_linear_model():
    class LinearModel(Module):
        def __init__(self):
            super(LinearModel, self).__init__()
            self.line1 = Linear(in_features=32, out_features=64)
            self.line2 = Linear(in_features=64, out_features=128, bias=False)
            self.line3 = Linear(in_features=128, out_features=10)

        def forward(self, inputs):
            x = self.line1(inputs)
            x = self.line2(x)
            x = self.line3(x)
            return x

    model = LinearModel()
    model.train()

    # for n, v in model.named_parameters():
    #     print(n, v.shape)

    def weight_init(m):
        if isinstance(m, Linear):
            m.weight.data.normal_(0, 0.01)
            if m.has_bias:
                m.bias.data.zero_()
    model.apply(weight_init)

    inputs = tensor(np.ones(shape=(5, 32)), ms.float32)
    output = model(inputs)
    assert output.shape == (5, 10)


test_linear_model()
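
Note: model.apply(weight_init) recurses over all submodules, mirroring torch.nn.Module.apply. A minimal stock-PyTorch sketch of the same initialization pattern is shown below for comparison; the one naming difference visible in this diff is that torch.nn.Linear exposes a bias attribute (possibly None) where the adapter layer exposes has_bias.

import torch
import torch.nn as nn

def torch_weight_init(m):
    if isinstance(m, nn.Linear):
        m.weight.data.normal_(0, 0.01)
        if m.bias is not None:  # adapter equivalent: m.has_bias
            m.bias.data.zero_()

ref = nn.Sequential(
    nn.Linear(32, 64),
    nn.Linear(64, 128, bias=False),
    nn.Linear(128, 10),
)
ref.apply(torch_weight_init)  # apply() visits every submodule
out = ref(torch.ones(5, 32))
assert out.shape == (5, 10)
assert torch.all(ref[0].bias == 0)  # bias zeroed by the initializer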
