|
- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- # @File : tryPaddle.py
- # @Date : 2021/11/2
- # @Desc :
- # @Author : Hou
- import os
- import numpy as np
- import argparse
-
- import paddle
- from paddlenlp.data import JiebaTokenizer, Stack, Tuple, Pad, Vocab
- from paddlenlp.datasets import load_dataset
- from model import BoWModel, BiLSTMAttentionModel, CNNModel, LSTMModel, GRUModel, RNNModel, SelfInteractiveAttention
-
-
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
# NOTE: help text previously claimed "defaults to gpu" while the actual
# default was "cpu"; the text now matches the default.
parser.add_argument('--device', choices=['cpu', 'gpu', 'xpu'], default="cpu", help="Select which device to train model, defaults to cpu.")
parser.add_argument("--batch_size", type=int, default=32, help="Total examples' number of a batch for training.")
parser.add_argument("--vocab_path", type=str, default="./senta_word_dict.txt", help="The path to vocabulary.")
parser.add_argument("--val_dir", type=str, default="./ChnSentiCorp/dev.tsv", help="The valid data file path.")
parser.add_argument('--network', choices=['bow', 'lstm', 'bilstm', 'gru', 'bigru', 'rnn', 'birnn', 'bilstm_attn', 'cnn'],
                    default="bilstm", help="Select which network to train, defaults to bilstm.")
parser.add_argument("--model_path", type=str, default='./checkpoints', help="The path of model parameter to be loaded.")
parser.add_argument("--checkpoint", type=str, default='final.pdparams', help="The checkpoint of model parameter to be loaded.")
args = parser.parse_args()
print(args)
# yapf: enable

# Module-level shortcuts used by the helpers below.
model_path = args.model_path
val_dir = args.val_dir
-
-
def data_load():
    """Load a ChnSentiCorp split from a local TSV file and preview it.

    NOTE(review): despite the dev-oriented naming, this loads ``test.tsv``
    under the ``train`` split key — confirm this is intentional before
    relying on it.

    :return: the loaded dataset (previously the function discarded it and
        returned ``None``, making it useless to callers).
    """
    # dev_ds = load_dataset("chnsenticorp", data_files=[r"./ChnSentiCorp_a/dev.tsv"], splits=["dev"])

    dev_ds = load_dataset("chnsenticorp", data_files=[r"./ChnSentiCorp_a/test.tsv"], splits=["train"])
    print(len(dev_ds))

    # Peek at a couple of examples to sanity-check parsing.
    print(dev_ds[0]['text'])
    print(dev_ds[0]['label'])

    print(dev_ds[5]['text'])
    print(dev_ds[5]['label'])

    return dev_ds
-
-
def get_model_arch(vocab_size, num_classes, pad_token_id):
    """Build the network selected by ``args.network`` and wrap it in ``paddle.Model``.

    :param vocab_size: size of the vocabulary (number of embedding rows)
    :param num_classes: number of output classes
    :param pad_token_id: index of the [PAD] token, used as ``padding_idx``
    :return: a ``paddle.Model`` wrapping the selected network
    :raises ValueError: if ``args.network`` names an unsupported architecture
    """
    network = args.network.lower()

    # One lazy builder per supported architecture; only the selected entry is
    # ever evaluated, so just one network gets instantiated.
    builders = {
        'bow': lambda: BoWModel(vocab_size, num_classes, padding_idx=pad_token_id),
        'bigru': lambda: GRUModel(
            vocab_size, num_classes,
            direction='bidirect', padding_idx=pad_token_id),
        'bilstm': lambda: LSTMModel(
            vocab_size, num_classes,
            direction='bidirect', padding_idx=pad_token_id),
        # Attention size is twice the LSTM hidden size because the LSTM is
        # bidirectional (forward + backward states are concatenated).
        'bilstm_attn': lambda: BiLSTMAttentionModel(
            attention_layer=SelfInteractiveAttention(hidden_size=2 * 196),
            vocab_size=vocab_size,
            lstm_hidden_size=196,
            num_classes=num_classes,
            padding_idx=pad_token_id),
        'birnn': lambda: RNNModel(
            vocab_size, num_classes,
            direction='bidirect', padding_idx=pad_token_id),
        'cnn': lambda: CNNModel(vocab_size, num_classes, padding_idx=pad_token_id),
        'gru': lambda: GRUModel(
            vocab_size, num_classes,
            direction='forward', padding_idx=pad_token_id, pooling_type='max'),
        'lstm': lambda: LSTMModel(
            vocab_size, num_classes,
            direction='forward', padding_idx=pad_token_id, pooling_type='max'),
        'rnn': lambda: RNNModel(
            vocab_size, num_classes,
            direction='forward', padding_idx=pad_token_id, pooling_type='max'),
    }

    if network not in builders:
        raise ValueError(
            "Unknown network: %s, it must be one of bow, lstm, bilstm, cnn, gru, bigru, rnn, birnn and bilstm_attn."
            % network)
    return paddle.Model(builders[network]())
-
-
def model_load(model):
    """Prepare *model* for evaluation and restore its trained parameters.

    No optimizer is configured (``None``) because the model is only used for
    inspection/evaluation here, not further training.

    :param model: a ``paddle.Model`` wrapper produced by ``get_model_arch``
    :return: the same model with parameters loaded from the checkpoint
    """
    print("Model load")

    # Evaluation-only setup: cross-entropy loss plus accuracy metric.
    model.prepare(None, paddle.nn.CrossEntropyLoss(), paddle.metric.Accuracy())

    # Restore trained weights from <model_path>/<checkpoint>.
    checkpoint_path = os.path.join(model_path, args.checkpoint)
    model.load(checkpoint_path)
    print("Loaded checkpoint from %s" % checkpoint_path)
    return model
-
-
def show_model_params(model):
    """Print every parameter tensor in the model, preceded by its name.

    :param model: object whose ``network.state_dict()`` maps parameter
        names to their values
    """
    state = model.network.state_dict()
    for name in state:
        print("==========", name)
        print(state[name])
-
-
def _parameterNameValue(model):
    """Snapshot all model parameters as a list of (name, numpy array) pairs.

    :param model: object whose ``network.state_dict()`` maps parameter
        names to tensors exposing ``.numpy()``
    :return: list of ``(name, ndarray)`` tuples in state-dict order
    """
    state = model.network.state_dict()
    return [(name, tensor.numpy()) for name, tensor in state.items()]
-
-
def _setModelPara(model, items):
    """Copy parameter values from *items* into the model's network.

    Bug fix: the previous implementation assigned into the dict returned by
    ``state_dict()``, which does NOT update the network's parameters (each
    call returns a fresh mapping; see the note in ``verify_model_setting``).
    The updated state dict must be pushed back via ``set_state_dict``.

    :param items: iterable of ``(param_name, numpy_value)`` pairs, e.g. the
        output of ``_parameterNameValue``
    :return: number of parameters that were matched and updated
    """
    state = model.network.state_dict()
    updates = {}
    for param_name, params in items:
        if param_name in state:
            updates[param_name] = paddle.to_tensor(params)
    if updates:
        state.update(updates)
        # Apply the modified mapping so the change actually takes effect.
        model.network.set_state_dict(state)
    return len(updates)
-
-
def verify_model_setting(model, var_name):
    """Demonstrate how to correctly overwrite a single model parameter.

    Mutating the dict returned by ``state_dict()`` alone does not change the
    network; the edited dict has to be pushed back with ``set_state_dict``.
    Prints the value before and after so the effect can be verified visually.

    :param model: a ``paddle.Model`` whose parameter will be replaced
    :param var_name: state-dict key of the parameter to overwrite
    :return: the model with ``var_name`` set to [0.0, 1.0]
    """
    # model.network.state_dict()[var_name] = paddle.to_tensor(np.array([0, 1], dtype=float))  # !!! cannot modify values in the parameter dict directly
    print("Update Value.....")
    state = model.network.state_dict()
    print(state[var_name])
    print(paddle.to_tensor(np.array([0, 1], dtype=float)))

    # The replacement tensor must match the original parameter's dtype.
    target_dtype = state[var_name].dtype
    print("DataType", target_dtype)

    state[var_name] = paddle.to_tensor(np.array([0.0, 1.0]), dtype=target_dtype)
    print(state[var_name])

    # Push the edited mapping back into the network to make the change stick.
    model.network.set_state_dict(state)
    print(model.network.state_dict()[var_name])
    print("Updated Value!!!")
    return model
-
-
if __name__ == "__main__":
    # Ensure the checkpoint directory exists, then pin the compute device.
    if not os.path.exists(model_path):
        os.makedirs(model_path)
    paddle.set_device(args.device.lower())

    # Loads vocab.
    vocab_local = Vocab.load_vocabulary(
        args.vocab_path, unk_token='[UNK]', pad_token='[PAD]')
    label_map = {0: 'negative', 1: 'positive'}

    # Build the model architecture.
    model_arch = get_model_arch(len(vocab_local), len(label_map), vocab_local.to_indices('[PAD]'))
    print(f'get model success')
    # Load the trained parameters into it.
    model_trained = model_load(model_arch)

    # # show_model_params(model_trained)
    # value_list = _parameterNameValue(model_trained)
    # print(len(value_list))
    # print(value_list[:2])
    #
    # num = _setModelPara(model_arch, value_list)
    # print(num)


    # print("network========>")
    # print(model_trained.network)
    print("network state dict keys========>")
    print(model_trained.network.state_dict().keys())

    # Demonstrate overwriting one parameter (the output-layer bias) and show
    # the value before and after; verify_model_setting returns the same model
    # object, so all three prints below reflect the updated value.
    value_name = "output_layer.bias"
    print(model_trained.network.state_dict()[value_name])
    updated_model = verify_model_setting(model_trained, value_name)
    print(model_trained.network.state_dict()[value_name])
    print(updated_model.network.state_dict()[value_name])


    #
    # print("model network param========>")
    # print(model_trained.network.parameters())
    #
    # print("model param========>")
    # print(model_trained.parameters())

    print("Done.")
|