Browse Source

fix stride slice error

master
lvyufeng 2 weeks ago
parent
commit
550a6db211
8 changed files with 55 additions and 22 deletions
  1. +2
    -0
      .gitignore
  2. +6
    -1
      eval.py
  3. +4
    -2
      scripts/run_eval.sh
  4. +3
    -2
      scripts/run_standalone_train.sh
  5. +4
    -1
      src/config.py
  6. +6
    -3
      src/rnns.py
  7. +21
    -12
      src/seq2seq.py
  8. +9
    -1
      train.py

+ 2
- 0
.gitignore View File

@@ -134,3 +134,5 @@ train/
eval/
output.txt*
target.txt*

.vscode/

+ 6
- 1
eval.py View File

@@ -16,6 +16,7 @@
import os
import argparse
import mindspore.common.dtype as mstype
from mindspore import log as logger
from mindspore.common.tensor import Tensor
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
@@ -30,7 +31,7 @@ def run_gru_eval():
Transformer evaluation.
"""
parser = argparse.ArgumentParser(description='GRU eval')
parser.add_argument("--device_target", type=str, default="GPU",
parser.add_argument("--device_target", type=str, default="Ascend",
help="device where the code will be implemented, default is Ascend")
parser.add_argument('--device_id', type=int, default=0, help='device id of GPU or Ascend, default is 0')
parser.add_argument('--device_num', type=int, default=1, help='Use device nums, default is 1')
@@ -41,6 +42,10 @@ def run_gru_eval():

context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, reserve_class_name_in_scope=False, \
device_id=args.device_id, save_graphs=False)
if args.device_target == "GPU":
if config.compute_type != mstype.float32:
logger.warning('GPU only support fp32 temporarily, run with fp32.')
config.compute_type = mstype.float32
mindrecord_file = args.dataset_path
if not os.path.exists(mindrecord_file):
print("dataset file {} not exists, please check!".format(mindrecord_file))


+ 4
- 2
scripts/run_eval.sh View File

@@ -20,9 +20,11 @@ exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export DEVICE_ID=4
export RANK_ID=0
export RANK_SIZE=1
export DEVICE_TARGET="Ascend"

get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
@@ -54,5 +56,5 @@ cp -r ../src ./eval
cd ./eval || exit
echo "start eval for device $DEVICE_ID"
env > env.log
python eval.py --ckpt_file=$CKPT_FILE --dataset_path=$DATASET_PATH &> log &
python eval.py --device_target=$DEVICE_TARGET --ckpt_file=$CKPT_FILE --dataset_path=$DATASET_PATH --device_id=$DEVICE_ID &> log &
cd ..

+ 3
- 2
scripts/run_standalone_train.sh View File

@@ -20,9 +20,10 @@ exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export DEVICE_ID=4
export RANK_ID=0
export RANK_SIZE=1
export DEVICE_TARGET="Ascend"
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
@@ -47,5 +48,5 @@ cp -r ../src ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
python train.py --device_id=$DEVICE_ID --dataset_path=$DATASET_PATH &> log &
python train.py --device_target=$DEVICE_TARGET --device_id=$DEVICE_ID --dataset_path=$DATASET_PATH &> log &
cd ..

+ 4
- 1
src/config.py View File

@@ -14,6 +14,7 @@
# ============================================================================
"""GRU config"""
from easydict import EasyDict
import mindspore.common.dtype as mstype

config = EasyDict({
"batch_size": 16,
@@ -37,5 +38,7 @@ config = EasyDict({
'scale_factor': 2,
'scale_window': 2000,
"warmup_ratio": 1/3.0,
"teacher_force_ratio": 0.5
"teacher_force_ratio": 0.5,
"compute_type": mstype.float16,
"dtype": mstype.float32
})

+ 6
- 3
src/rnns.py View File

@@ -43,7 +43,9 @@ class DynamicRNN(nn.Cell):
t = 0
h = h_0
while t < time_step:
h = self.cell(x[t], h, w_ih, w_hh, b_ih, b_hh)
x_t = x[t:t+1:1]
x_t = P.Squeeze(0)(x_t)
h = self.cell(x_t, h, w_ih, w_hh, b_ih, b_hh)
if self.is_lstm:
outputs.append(h[0])
else:
@@ -70,7 +72,9 @@ class DynamicRNN(nn.Cell):
state_t = h_t
t = 0
while t < time_step:
h_t = self.cell(x[t], state_t, w_ih, w_hh, b_ih, b_hh)
x_t = x[t:t+1:1]
x_t = P.Squeeze(0)(x_t)
h_t = self.cell(x_t, state_t, w_ih, w_hh, b_ih, b_hh)
seq_cond = seq_length > t
if self.is_lstm:
state_t_0 = P.Select()(seq_cond, h_t[0], state_t[0])
@@ -200,7 +204,6 @@ class RNNBase(nn.Cell):
else:
h_n = P.Concat(0)(h_n)
return output, h_n.view(h.shape)
return x, h
def _stacked_dynamic_rnn(self, x, h, seq_length):
"""stacked mutil_layer dynamic_rnn"""


+ 21
- 12
src/seq2seq.py View File

@@ -30,8 +30,8 @@ class Attention(nn.Cell):
super(Attention, self).__init__()
self.text_len = config.max_length
self.attn = nn.Dense(in_channels=config.hidden_size * 3,
out_channels=config.hidden_size).to_float(mstype.float32)
self.fc = nn.Dense(config.hidden_size, 1, has_bias=False).to_float(mstype.float32)
out_channels=config.hidden_size).to_float(config.compute_type)
self.fc = nn.Dense(config.hidden_size, 1, has_bias=False).to_float(config.compute_type)
self.expandims = P.ExpandDims()
self.tanh = P.Tanh()
self.softmax = P.Softmax()
@@ -40,6 +40,9 @@ class Attention(nn.Cell):
self.concat = P.Concat(axis=2)
self.squeeze = P.Squeeze(axis=2)
self.cast = P.Cast()
self.dtype = config.dtype
self.compute_type = config.compute_type

def construct(self, hidden, encoder_outputs):
'''
Attention construction
@@ -59,9 +62,9 @@ class Attention(nn.Cell):
energy = self.tanh(out)
attention = self.fc(energy)
attention = self.squeeze(attention)
attention = self.cast(attention, mstype.float32)
attention = self.cast(attention, self.dtype)
attention = self.softmax(attention)
attention = self.cast(attention, mstype.float32)
attention = self.cast(attention, self.compute_type)
return attention

class Encoder(nn.Cell):
@@ -77,8 +80,8 @@ class Encoder(nn.Cell):
self.vocab_size = config.src_vocab_size
self.embedding_size = config.encoder_embedding_size
self.embedding = nn.Embedding(self.vocab_size, self.embedding_size)
self.rnn = GRU(input_size=self.embedding_size, hidden_size=self.hidden_size, bidirectional=True).to_float(mstype.float32)
self.fc = nn.Dense(2*self.hidden_size, self.hidden_size).to_float(mstype.float32)
self.rnn = GRU(input_size=self.embedding_size, hidden_size=self.hidden_size, bidirectional=True).to_float(config.compute_type)
self.fc = nn.Dense(2*self.hidden_size, self.hidden_size).to_float(config.compute_type)
self.shape = P.Shape()
self.transpose = P.Transpose()
self.p = P.Print()
@@ -86,6 +89,8 @@ class Encoder(nn.Cell):
self.text_len = config.max_length
self.squeeze = P.Squeeze(axis=0)
self.tanh = P.Tanh()
self.concat = P.Concat(2)
self.dtype = config.dtype

def construct(self, src):
'''
@@ -100,7 +105,7 @@ class Encoder(nn.Cell):
'''
embedded = self.embedding(src)
embedded = self.transpose(embedded, (1, 0, 2))
embedded = self.cast(embedded, mstype.float32)
embedded = self.cast(embedded, self.dtype)
output, hidden = self.rnn(embedded)
hidden = self.transpose(hidden, (1, 0, 2))
hidden = hidden.view(hidden.shape[0], -1)
@@ -121,7 +126,7 @@ class Decoder(nn.Cell):
self.vocab_size = config.trg_vocab_size
self.embedding_size = config.decoder_embedding_size
self.embedding = nn.Embedding(self.vocab_size, self.embedding_size)
self.rnn = GRU(input_size=self.embedding_size + self.hidden_size*2, hidden_size=self.hidden_size).to_float(mstype.float32)
self.rnn = GRU(input_size=self.embedding_size + self.hidden_size*2, hidden_size=self.hidden_size).to_float(config.compute_type)
self.text_len = config.max_length
self.shape = P.Shape()
self.transpose = P.Transpose()
@@ -133,11 +138,13 @@ class Decoder(nn.Cell):
self.log_softmax = P.LogSoftmax(axis=1)
weight, bias = dense_default_state(self.embedding_size+self.hidden_size*3, self.vocab_size)
self.fc = nn.Dense(self.embedding_size+self.hidden_size*3, self.vocab_size,
weight_init=weight, bias_init=bias).to_float(mstype.float32)
weight_init=weight, bias_init=bias).to_float(config.compute_type)
self.attention = Attention(config)
self.bmm = P.BatchMatMul()
self.dropout = nn.Dropout(0.7)
self.expandims = P.ExpandDims()
self.dtype = config.dtype

def construct(self, inputs, hidden, encoder_outputs):
'''
Decoder construction
@@ -153,21 +160,23 @@ class Decoder(nn.Cell):
'''
embedded = self.embedding(inputs)
embedded = self.transpose(embedded, (1, 0, 2))
embedded = self.cast(embedded, mstype.float32)
embedded = self.cast(embedded, self.dtype)
attn = self.attention(hidden, encoder_outputs)
attn = self.expandims(attn, 1)
encoder_outputs = self.transpose(encoder_outputs, (1, 0, 2))
weight = self.bmm(attn, encoder_outputs)
weight = self.transpose(weight, (1, 0, 2))
weight = self.cast(weight, self.dtype)
emd_con = self.concat((embedded, weight))
# print(emd_con.shape)
output, hidden = self.rnn(emd_con)
output = self.cast(output, self.dtype)
out = self.concat((embedded, output, weight))
out = self.squeeze(out)
hidden = self.squeeze(hidden)
prediction = self.fc(out)
prediction = self.dropout(prediction)
prediction = self.cast(prediction, mstype.float32)
prediction = self.cast(prediction, mstype.float32)
prediction = self.cast(prediction, self.dtype)
pred_prob = self.log_softmax(prediction)
pred_prob = self.expandims(pred_prob, 0)
return pred_prob, hidden


+ 9
- 1
train.py View File

@@ -17,8 +17,10 @@ import os
import time
import argparse
import ast
import mindspore.common.dtype as mstype
from mindspore.context import ParallelMode
from mindspore import context
from mindspore import log as logger
from mindspore.communication.management import init
from mindspore.train.callback import Callback, CheckpointConfig, ModelCheckpoint, TimeMonitor
from mindspore.train import Model
@@ -33,6 +35,8 @@ from src.lr_schedule import dynamic_lr
set_seed(1)

parser = argparse.ArgumentParser(description="GRU training")
parser.add_argument("--device_target", type=str, default="Ascend",
help="device where the code will be implemented, default is Ascend")
parser.add_argument("--run_distribute", type=ast.literal_eval, default=False, help="Run distribute, default: false.")
parser.add_argument("--dataset_path", type=str, default=None, help="Dataset path")
parser.add_argument("--pre_trained", type=str, default=None, help="Pretrained file path.")
@@ -43,7 +47,11 @@ parser.add_argument('--ckpt_path', type=str, default='outputs/', help='Checkpoin
parser.add_argument('--outputs_dir', type=str, default='./', help='Checkpoint save location. Default: outputs/')
args = parser.parse_args()

context.set_context(mode=context.GRAPH_MODE, device_target="GPU", device_id=args.device_id, save_graphs=False)
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id, save_graphs=False)
if args.device_target == "GPU":
if config.compute_type != mstype.float32:
logger.warning('GPU only support fp32 temporarily, run with fp32.')
config.compute_type = mstype.float32

def get_ms_timestamp():
t = time.time()


Loading…
Cancel
Save