|
- import textattack
- import transformers
-
-
# Local checkpoint directory for every supported (model, dataset) pair.
# Paths are relative to the current working directory.
_MODEL_DIRS = {
    ("ALBERT", "Yelp-polarity"): "./model/albert-base-v2-yelp-polarity",
    ("Bert-base-uncased", "Yelp-polarity"): "./model/bert-base-uncased-yelp-polarity",
    # NOTE(review): BERT on RTE loads the Yelp-polarity checkpoint; this looks
    # like a copy/paste slip but is kept unchanged -- confirm the intended path.
    ("BERT", "RTE"): "./model/bert-base-uncased-yelp-polarity",
    ("RoBERTa", "RTE"): "./model/roberta-base-RTE",
    # The XLNet checkpoint was originally keyed as ('XLNET', 'RTE') and later
    # re-keyed as ('BERT', 'SST'); both keys are accepted.
    ("XLNET", "RTE"): "./model/xlnet-base-cased-RTE",
    ("BERT", "SST"): "./model/xlnet-base-cased-RTE",
    ("ALBERT", "MRPC"): "./model/albert-base-v2-MRPC",
    ("DistilBert-base-cased", "MRPC"): "./model/distilbert-base-cased-MRPC",
    ("DistilBert-base-uncased", "MRPC"): "./model/distilbert-base-uncased-MRPC",
    ("ALBERT", "STS-B"): "./model/albert-base-v2-STS-B",
    ("DistilBert-base-cased", "STS-B"): "./model/distilbert-base-cased-STS-B",
    ("DistilBert-base-uncased", "STS-B"): "./model/distilbert-base-uncased-STS-B",
}

# Attack-method key -> name of the recipe class exported by
# ``textattack.attack_recipes``. Names are resolved lazily with getattr so the
# table can be built without touching the textattack package.
_ATTACK_RECIPES = {
    "alzantot": "GeneticAlgorithmAlzantot2018",
    "fast-alzantot": "FasterGeneticAlgorithmJia2019",
    "BAE": "BAEGarg2019",
    "BERT-attack": "BERTAttackLi2020",
    # The CheckList recipe was originally keyed as 'checklist' and later
    # re-keyed as 'GAN'; both keys are accepted.
    "checklist": "CheckList2020",
    "GAN": "CheckList2020",
    "clare": "CLARE2020",
    "seq2sick": "Seq2SickCheng2018BlackBox",
}

# Dataset key -> (positional HuggingFaceDataset arguments, split).
_DATASETS = {
    "Yelp-polarity": (("yelp_polarity",), "test"),
    "RTE": (("glue", "rte"), "validation"),
    "MRPC": (("glue", "mrpc"), "validation"),
}


def TextAttack(params):
    """Run a TextAttack recipe against a local HuggingFace classifier.

    Args:
        params: Mapping with the keys ``"dataset"``, ``"model"`` and
            ``"attack_method"``; the latter is itself a mapping with
            ``"method"`` (attack-method key) and ``"sampleSize"`` (passed to
            ``AttackArgs`` as ``num_examples``). Other keys are ignored.

    Returns:
        None. Results are written to ``AttackLog/log-<model>-<method>.csv``
        through the ``log_to_csv`` attack argument.

    Raises:
        KeyError: If a required key is missing from ``params``.
        ValueError: If the (model, dataset) pair or the attack method is not
            one of the supported keys. Raised before anything is loaded.
    """
    dataset_name = params["dataset"]
    model_name = params["model"]
    attack_method = params["attack_method"]["method"]
    num_examples = params["attack_method"]["sampleSize"]

    print(params)

    # Validate the selection up front so an unsupported request fails with a
    # clear message instead of an unbound-variable error after model loading.
    model_dir = _MODEL_DIRS.get((model_name, dataset_name))
    if model_dir is None:
        raise ValueError(
            "Unsupported model/dataset combination: {!r} on {!r}".format(
                model_name, dataset_name))
    recipe_name = _ATTACK_RECIPES.get(attack_method)
    if recipe_name is None:
        raise ValueError(
            "Unsupported attack method: {!r} (expected one of {})".format(
                attack_method, sorted(_ATTACK_RECIPES)))

    print("Loading model...")
    model = transformers.AutoModelForSequenceClassification.from_pretrained(model_dir)
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_dir)
    model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)

    print('Attacking..')
    # ``build`` is a static method of the recipe class, not of the recipe's
    # submodule, so the class is looked up by its exported name.
    recipe_cls = getattr(textattack.attack_recipes, recipe_name)
    attack = recipe_cls.build(model_wrapper)

    attack_args = textattack.AttackArgs(
        num_examples=num_examples,
        log_to_csv="AttackLog/log-{}-{}.csv".format(model_name, attack_method),
        checkpoint_interval=5,
        checkpoint_dir="checkpoints",
        disable_stdout=True,
    )

    print("Loading dataset...")
    dataset_spec = _DATASETS.get(dataset_name)
    if dataset_spec is None:
        # Some model keys (SST, STS-B) have no dataset loader. As before, no
        # attack is run for them, but the skip is now reported.
        print("No dataset loader for {!r}; attack skipped.".format(dataset_name))
        return

    hf_args, split = dataset_spec
    dataset = textattack.datasets.HuggingFaceDataset(*hf_args, split=split)
    attacker = textattack.Attacker(attack, dataset, attack_args)
    attacker.attack_dataset()
-
-
if __name__ == "__main__":
    # Demo run. The 'model'/'dataset' pair and the attack 'method' must be
    # keys that TextAttack() dispatches on; 'GAN' is the key that selects the
    # CheckList recipe. 'object' is not read by TextAttack().
    params = {
        'dataset': 'RTE',
        'model': 'RoBERTa',
        'object': 'BuiltInSystem',
        'attack_method': {
            'method': 'GAN',
            'sampleSize': 3,
        },
    }
    TextAttack(params)
|