#!/usr/bin/env python
# coding=UTF-8
"""
@Author: WEN Hao
@LastEditors: WEN Hao
@Description:
@Date: 2021-09-08
@LastEditTime: 2022-04-01

Built-in LSTM classification model.
"""

import json
import glob
import os
from typing import List, Optional

import torch
from torch import nn

from ._glove_embedding_layer import GloveEmbeddingLayer
from ..pytorch_model import PyTorchNLPVictimModel
from ..Tokenizers import GloveTokenizer
from utils.misc import nlp_cache_dir
from utils._download_data import download_if_needed


__all__ = [
    "VictimLSTMForClassification",
]


class VictimLSTMForClassification(PyTorchNLPVictimModel):
    """
    Binary (or 5-class) classification model for English text.
    """

    __name__ = "VictimLSTMForClassification"

    def __init__(self, path: Optional[str] = None) -> None:
        """Load the pretrained LSTM model from ``path``, downloading it first if absent."""
        self._path = path or os.path.join(nlp_cache_dir, "lstm-imdb")
        if not os.path.exists(self._path):
            download_if_needed(
                uri="lstm-imdb",
                source="aitesting",
                dst_dir=nlp_cache_dir,
                extract=True,
            )
        model = _LSTMForClassification.from_pretrained(self._path)
        super().__init__(model, model.tokenizer)

    @property
    def path(self) -> str:
        return self._path

    def extra_repr_keys(self) -> List[str]:
        return ["path"]

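# A minimal usage sketch (assumes network access to the "aitesting" source,
# or the "lstm-imdb" weights already cached under `nlp_cache_dir`):
#
#     victim = VictimLSTMForClassification()
#     print(victim.path)  # ".../lstm-imdb" inside the cache directory
#
# `victim` then exposes the wrapped model and tokenizer through the
# `PyTorchNLPVictimModel` interface.
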

class _LSTMForClassification(nn.Module):
    """A long short-term memory neural network for text classification.

    We use different versions of this network to pretrain models for
    text classification.
    """

    __name__ = "_LSTMForClassification"
    _BUILTIN_MODELS = [
        "lstm-sst2",
        "lstm-imdb",
    ]

    def __init__(
        self,
        hidden_size: int = 150,
        depth: int = 1,
        dropout: float = 0.3,
        num_labels: int = 2,
        max_seq_length: int = 128,
        model_path: Optional[str] = None,
        emb_layer_trainable: bool = True,
    ) -> None:
        """Build the network; if ``model_path`` is given, also load weights from it."""
        super().__init__()
        self._config = {
            "hidden_size": hidden_size,
            "depth": depth,
            "dropout": dropout,
            "num_labels": num_labels,
            "max_seq_length": max_seq_length,
            "model_path": model_path,
            "emb_layer_trainable": emb_layer_trainable,
        }
        self._config["architectures"] = self.__name__
        if depth <= 1:
            # Fix error where we ask for non-zero dropout with only 1 layer.
            # nn.module.RNN won't add dropout for the last recurrent layer,
            # so if that's all we have, this will display a warning.
            dropout = 0
        self.drop = nn.Dropout(dropout)
        self.emb_layer_trainable = emb_layer_trainable
        self.emb_layer = GloveEmbeddingLayer(emb_layer_trainable=emb_layer_trainable)
        self.word2id = self.emb_layer.word2id
        self.encoder = nn.LSTM(
            input_size=self.emb_layer.n_d,
            hidden_size=hidden_size // 2,
            num_layers=depth,
            dropout=dropout,
            bidirectional=True,
        )
        d_out = hidden_size
        self.out = nn.Linear(d_out, num_labels)
        self.tokenizer = GloveTokenizer(
            word_id_map=self.word2id,
            unk_token_id=self.emb_layer.oovid,
            pad_token_id=self.emb_layer.padid,
            max_length=max_seq_length,
        )
        if model_path is not None:
            # load pretrained weights directly from the given directory
            self.load_state_dict(load_cached_state_dict(model_path))

    def save_pretrained(self, output_path: str) -> None:
        """Save model weights and config to ``output_path``."""
        os.makedirs(output_path, exist_ok=True)
        state_dict = {k: v.cpu() for k, v in self.state_dict().items()}
        torch.save(
            state_dict,
            os.path.join(output_path, "pytorch_model.bin"),
        )
        with open(os.path.join(output_path, "config.json"), "w") as f:
            json.dump(self._config, f)

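    # Round-trip sketch: `save_pretrained` writes exactly the two files that
    # `from_pretrained` consumes ("pytorch_model.bin" and "config.json"), so a
    # saved directory can be reloaded as-is (paths here are illustrative):
    #
    #     model.save_pretrained("/tmp/my-lstm")
    #     restored = _LSTMForClassification.from_pretrained("/tmp/my-lstm")
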
    @classmethod
    def from_pretrained(cls, name_or_path: str) -> "_LSTMForClassification":
        """Load a trained LSTM model by built-in name or from a local path."""
        if name_or_path in _LSTMForClassification._BUILTIN_MODELS:
            uri = f"models_v2/classification/{name_or_path.replace('-', '/')}"
            path = os.path.join(nlp_cache_dir, name_or_path)
            if not os.path.exists(path):
                download_if_needed(
                    uri, source="textattack", dst_dir=nlp_cache_dir, extract=True
                )
        else:
            path = name_or_path
            if not os.path.exists(path):
                raise ValueError(f"Path {path} does not exist")

        config_path = os.path.join(path, "config.json")

        if os.path.exists(config_path):
            with open(config_path, "r") as f:
                config = json.load(f)
        else:
            # Default config
            config = {
                "architectures": "_LSTMForClassification",
                "hidden_size": 150,
                "depth": 1,
                "dropout": 0.3,
                "num_labels": 2,
                "max_seq_length": 128,
                "model_path": None,
                "emb_layer_trainable": True,
            }
        config.pop("architectures", None)
        # weights are loaded explicitly below, not inside __init__
        config["model_path"] = None
        model = cls(**config)
        state_dict = load_cached_state_dict(path)
        model.load_state_dict(state_dict)
        return model

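    # Loading sketch (assumes the built-in weights are cached or downloadable
    # from the "textattack" source):
    #
    #     model = _LSTMForClassification.from_pretrained("lstm-imdb")
    #     model.eval()
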
    def forward(self, _input: torch.Tensor) -> torch.Tensor:
        # ensure RNN module weights are part of single contiguous chunk of memory
        self.encoder.flatten_parameters()

        # (batch, seq_len) -> (seq_len, batch) -> (seq_len, batch, emb_dim)
        emb = self.emb_layer(_input.t())
        emb = self.drop(emb)

        # output: (seq_len, batch, hidden_size), both directions concatenated
        output, _ = self.encoder(emb)
        # max-pool over the time dimension -> (batch, hidden_size)
        output = torch.max(output, dim=0)[0]

        output = self.drop(output)
        # logits of shape (batch, num_labels)
        pred = self.out(output)
        return pred

    def get_input_embeddings(self) -> nn.Module:
        return self.emb_layer.embedding


def load_cached_state_dict(model_folder_path: str) -> dict:
    """Load the state dict from the first file matching ``*model.bin`` in the folder."""
    model_path_list = glob.glob(os.path.join(model_folder_path, "*model.bin"))
    if not model_path_list:
        raise FileNotFoundError(
            f"model.bin not found in model folder {model_folder_path}."
        )
    model_path = model_path_list[0]
    state_dict = torch.load(model_path, map_location=torch.device("cpu"))
    return state_dict
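

if __name__ == "__main__":
    # Smoke-test sketch: build an untrained model and run a dummy forward pass.
    # This assumes the GloVe resources needed by `GloveEmbeddingLayer` are
    # available locally (they may be downloaded on first use).
    model = _LSTMForClassification()
    model.eval()
    # `forward` expects a batch-first (batch, seq_len) tensor of token ids
    dummy_input = torch.full((4, 128), model.emb_layer.padid, dtype=torch.long)
    with torch.no_grad():
        logits = model(dummy_input)
    print(logits.shape)  # expected: torch.Size([4, 2])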