|
- from queue import Queue
- import threading
- import paho.mqtt.client as mqtt
- import json
- import ctypes
- import inspect
- import sys
-
-
# Make sibling project code and the visualization tool scripts importable.
# NOTE(review): "..." is a literal directory name, not the grandparent
# directory ("../.." would be) — confirm which one was intended.
sys.path.extend([
    "..",
    "...",
    "/data0/BigPlatform/ZJPlatform/009_Visualization/torch-cam-master/scripts/",
    "/data0/BigPlatform/ZJPlatform/009_Visualization/t-SNE/",
])


# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

# EMQ (MQTT) topics this service subscribes to.
TOPIC_STOP_TASK = 'textattackstop'
TOPIC_TRAIN_START = 'trainstart'
TOPIC_ATTACK_START = 'textattackstart'

# Reserved topics, not handled yet:
# TEXT_PAUSE_TASK="textpausetask"
# TEXT_CONTINUE_TASK="textcontinuetask"

# Maximum number of threads / concurrently running tasks.
MAX_THREAD = 5
# Task registry: str(taskId) -> threading.Thread executing that task.
taskMap = {}


# ---------------------------------------------------------------------------
# Route handlers
# ---------------------------------------------------------------------------
-
-
# Stop a task.
def stop_task(queue: Queue, topic, payload):
    """Forcibly stop the task identified by ``payload['taskId']``.

    Puts a ``'cancel'`` message on ``queue`` if a running thread was
    interrupted. Otherwise it puts a ``'finish'`` message: the task already
    finished, or it is not (or no longer) registered in ``taskMap``.
    ``can_doing()`` removes finished entries from ``taskMap``.

    :param queue: outbound message queue (published by ``consume``).
    :param topic: the MQTT topic that triggered this handler (unused).
    :param payload: decoded request; must contain ``'taskId'``.
    :raises KeyError: if ``payload`` has no ``'taskId'``.
    """
    taskId = payload['taskId']
    complete = {
        'topic': 'cancel',  # topic reported for a forced termination
        'data': {
            'taskId': taskId
        }
    }
    # An unknown id is treated like a finished task instead of raising
    # KeyError. can_doing() prunes finished threads from taskMap, so a stop
    # request arriving after that cleanup must not crash the handler.
    task = taskMap.get(str(taskId))
    stopped = False
    if task is not None and task.is_alive():
        try:
            # The task is still running and has to be terminated forcibly.
            stop_thread(task)
            stopped = True
        except ValueError:
            # The thread exited between is_alive() and the async raise.
            pass
    if not stopped:
        complete['topic'] = 'finish'

    queue.put(complete)
-
-
# Start a training task.
def train_start(queue: Queue, topic, payload):
    """Handle a ``trainstart`` request.

    Reads the task id and the training configuration. ``payload['data']``
    must provide ``platform``, ``model``, ``dataset`` and ``attack``.

    If ``can_doing()`` reports capacity, nothing further happens yet, because
    launching the training worker is still a TODO. Otherwise a ``'finish'``
    message reporting a full task queue is put on ``queue``.

    :param queue: outbound message queue, to be handed to the worker.
    :param topic: the MQTT topic that triggered this handler (unused).
    :param payload: decoded request.
    :raises KeyError: if any required key is missing.
    """
    taskId = payload['taskId']
    cfg = payload['data']

    # Repack the arguments for the future worker.
    # TODO: validate them here before starting any work.
    params = dict(
        platform=cfg['platform'],
        model=cfg['model'],
        dataset=cfg['dataset'],
        attack=cfg['attack'],
        taskId=taskId,
        queue=queue,
    )

    if not can_doing():
        queue.put({
            'topic': 'finish',
            'data': {
                'taskId': taskId,
                'message': '任务队列已满'
            }
        })
        return

    # TODO: launch the training worker with ``params`` and register it, e.g.
    #   from Image_Train_Test_dj import Test
    #   t = threading.Thread(target=Test, args=(params,), daemon=True)
    #   t.start()
    #   taskMap[str(taskId)] = t
-
-
def MultiTherad(ParamsSet, target):
    """Run ``target(params)`` in one daemon thread per entry of ``ParamsSet``.

    All threads are started immediately, in list order, and none is joined.
    Each thread is then registered in ``taskMap`` under
    ``str(params["taskId"])``.

    :param ParamsSet: sequence of parameter dicts. Each dict is passed as the
        single positional argument to ``target`` and must contain ``"taskId"``.
    :param target: callable executed by each thread.
    :return: None

    NOTE(review): entries that share a ``taskId`` (as the dicts built by
    ``attack_start`` do) overwrite each other in ``taskMap``. Only the last
    such thread is then visible to ``can_doing`` / ``stop_task``.
    """
    # ``daemon=True`` replaces Thread.setDaemon(), deprecated since Python 3.10.
    threads = [
        threading.Thread(target=target, args=(params,), daemon=True)
        for params in ParamsSet
    ]
    # No thread is joined, so starting them one after another is all that is
    # needed.
    for thread in threads:
        thread.start()
    for params, thread in zip(ParamsSet, threads):
        taskMap[str(params["taskId"])] = thread
-
-
def SingleTherad(ParamsSet, target):
    """Start one daemon thread per entry of ``ParamsSet`` running ``target``.

    Despite the name, the threads run concurrently. Every thread is started
    right away, in list order, and none is joined. Each thread is registered
    in ``taskMap`` under ``str(params["taskId"])``, and entries sharing a
    ``taskId`` overwrite each other there.

    :param ParamsSet: sequence of parameter dicts. Each is passed as the only
        positional argument to ``target`` and must contain ``"taskId"``.
    :param target: callable run by each thread.
    :return: None
    """
    threads = []
    for params in ParamsSet:
        # ``daemon=True`` replaces Thread.setDaemon(), deprecated since 3.10.
        threads.append(threading.Thread(target=target, args=(params,), daemon=True))
    for thread in threads:
        thread.start()
    for params, thread in zip(ParamsSet, threads):
        taskMap[str(params["taskId"])] = thread
-
-
# Start an attack task.
def attack_start(queue: Queue, topic, payload):
    """Handle a ``textattackstart`` request.

    Parses ``payload`` and builds one parameter dict per requested white-box
    attack method. It then dispatches on ``payload["step"]``:

    * ``"ATTACK"`` with ``evaluationObject == "BuiltInSystem"`` runs
      ``TextAttack`` once per attack method via ``MultiTherad``.
    * ``"NEURON_VIS"``, ``"SAMPLE_VIS"`` and ``"ROBUST"`` are placeholders
      that only print the step name.
    * Any other step or evaluation object is ignored.

    If ``can_doing()`` reports no spare capacity, a ``'finish'`` message about
    the full task queue is put on ``queue`` instead.

    :param queue: outbound message queue, also handed to every worker.
    :param topic: the MQTT topic that triggered this handler (unused).
    :param payload: decoded request (see the key lookups below).
    :raises KeyError: if a required key is missing from ``payload``.
    """
    taskId = payload['taskId']        # used to build a unique storage path
    dataModal = payload["dataModal"]  # data modality under test
    step = payload["step"]            # pipeline stage to run
    data = payload["data"]
    print(data)
    # appMode (standalone vs. federated learning) and sceneName are not used
    # yet, but are still looked up so a request missing them is rejected.
    dataAppMode = data["appMode"]
    dataEvaluationObject = data["evaluationObject"]
    dataSceneName = data["sceneName"]
    dataPlatform = data['platformFramework']
    dataDataset = data['dataset']
    dataDepthModel = data['depthModel']

    # Only white-box attacks are supported for now; black-box attacks and the
    # robustness / interpretability sections of ``data`` are TODO.
    attackPayloadSet = list(data["vulnerabilityAssessmentMethod"]['whiteBoxAttack'])
    # TODO: validate the parameters before starting any work.

    if not can_doing():
        queue.put({
            'topic': 'finish',
            'data': {
                'taskId': taskId,
                'message': '任务队列已满'
            }
        })
        return

    # One parameter dict per attack method; each is run in its own thread.
    # The payload's 'dataModal' value is forwarded under the key 'dataModel'.
    attackParamsSet = [
        {
            'platform': dataPlatform,
            'model': dataDepthModel,
            'dataset': dataDataset,
            'taskId': taskId,
            'object': dataEvaluationObject,
            'queue': queue,
            'attack_method': attack_method,
            'dataModel': dataModal,
        }
        for attack_method in attackPayloadSet
    ]

    if step == "ATTACK":
        if dataEvaluationObject == "BuiltInSystem":
            # Imported only here, where it is used, so the other steps do not
            # depend on this module being importable.
            from TextAttack import TextAttack
            print("Step: ATTACK")
            MultiTherad(attackParamsSet, TextAttack)
    elif step == "NEURON_VIS":
        print("Step: NEURON_VIS")
        # TODO: neuron visualization (Image_Testing.Image_Testing via SingleTherad).
    elif step == "SAMPLE_VIS":
        print("Step: SAMPLE_VIS")
        # TODO: sample visualization (all_cam_dj.Main via MultiTherad, with the
        # interpretableSemantics "Visualization" methods added to each dict).
    elif step == "ROBUST":
        print("Step: ROBUST")
        # TODO: robustness evaluation (Image_Robustness.Image_Robustness via MultiTherad).
    # TODO: "EDGE_VIS" (tsne.TSNE via MultiTherad) is not wired up yet.
-
-
def no_topic(queue: Queue, topic, payload):
    """Fallback handler for topics that have no registered route.

    Replies with a ``'finish'`` message saying the requested kind of
    computation is not supported.

    :param queue: outbound message queue.
    :param topic: the unrecognized topic, echoed back in the message.
    :param payload: decoded request; must contain ``'taskId'``.
    :raises KeyError: if ``payload`` has no ``'taskId'``.
    """
    reply_data = {
        'taskId': payload['taskId'],
        'message': '不支持此类型的计算:{}'.format(topic),
    }
    queue.put({'topic': 'finish', 'data': reply_data})
-
-
def consume(client: mqtt.Client, queue: Queue):
    """Queue consumer: publish every queued message to MQTT, forever.

    Meant to run in a daemon thread (started by ``Server.register``). Each
    item taken from ``queue`` must be a dict with a ``'topic'`` key. The whole
    dict is JSON-encoded (non-ASCII characters kept as-is) and published on
    that topic with QoS 1.

    A bad item is reported and skipped instead of killing the thread. Examples
    are a missing ``'topic'``, a value that is not JSON-serializable, or a
    publish error. Nothing visible in this module restarts the consumer, so
    its death would leave every later message unsent. Producers would then
    block once the bounded queue fills up. ``task_done()`` is always called so
    ``queue.join()`` cannot hang.
    """
    import traceback

    while True:
        obj = queue.get()
        try:
            topic = obj['topic']
            print(obj)
            # Publish the message with QoS 1.
            client.publish(topic, json.dumps(obj, ensure_ascii=False), 1)
        except Exception:
            # Top-level boundary of a long-lived worker: log and keep going.
            traceback.print_exc()
        finally:
            queue.task_done()
-
-
def can_doing() -> bool:
    """Check whether a new task may be started.

    Side effect: entries of ``taskMap`` whose thread has finished are removed.

    :return: True if fewer than ``MAX_THREAD`` registered threads are still
        alive, False otherwise.
    """
    finished = [key for key, thread in taskMap.items() if not thread.is_alive()]
    alive_count = len(taskMap) - len(finished)
    # Prune outside the iteration above; mutating a dict while iterating it
    # raises RuntimeError.
    for key in finished:
        del taskMap[key]
    # Use the MAX_THREAD setting rather than a duplicated literal limit.
    return alive_count < MAX_THREAD
-
-
def __async__raise(tid, exctype):
    """Asynchronously raise ``exctype`` in the thread whose ident is ``tid``.

    Uses the CPython C API ``PyThreadState_SetAsyncExc``. The call only
    schedules the exception; the target thread raises it later, once the
    interpreter gets to run it.

    :param tid: ``threading.Thread.ident`` of the target thread.
    :param exctype: exception class to raise; an instance is replaced by its
        class.
    :raises ValueError: if no thread with id ``tid`` exists (e.g. it already
        exited).
    :raises SystemError: if more than one thread state was modified; the
        request is reverted before raising.
    """
    # Since Python 3.7 the C API takes an ``unsigned long`` thread id, so pass
    # it as c_ulong rather than a signed c_long.
    tid = ctypes.c_ulong(tid)
    if not inspect.isclass(exctype):
        exctype = type(exctype)
    res = ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, ctypes.py_object(exctype))
    if res == 0:
        raise ValueError("invalid thread id")
    elif res != 1:
        # More than one thread state was modified: undo by calling again with
        # exc=NULL (None) before reporting the failure.
        ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, None)
        raise SystemError("PyThreadState_SetAsyncExc failed")
-
-
def stop_thread(thread: threading.Thread):
    """Terminate ``thread`` by asynchronously raising ``SystemExit`` in it.

    :param thread: a started thread; its ``ident`` identifies the target.
    :raises ValueError: from ``__async__raise`` when no thread with that ident
        exists.
    """
    target_ident = thread.ident
    __async__raise(target_ident, SystemExit)
-
-
# Service class.
class Server(object):
    """Routes incoming MQTT messages to task handlers and publishes replies.

    Handlers are called as ``handler(queue, topic, payload)`` and report
    results by putting dicts on ``self.queue``. A background ``consume``
    thread started by ``register`` publishes them.
    """

    def __init__(self, client: mqtt.Client):
        self.client = client
        # topic -> handler; topics not listed fall back to ``no_topic``.
        self.router = {
            TOPIC_STOP_TASK: stop_task,
            TOPIC_TRAIN_START: train_start,
            TOPIC_ATTACK_START: attack_start
        }
        # Bounded outbound queue: at most 10 pending messages are allowed.
        self.queue = Queue(maxsize=10)

    def register(self):
        """Subscribe the client to the handled topics and start the consumer.

        Each call starts one more daemon thread running ``consume``.
        """
        self.client.subscribe(TOPIC_STOP_TASK)     # stop a task
        self.client.subscribe(TOPIC_TRAIN_START)   # start a training task
        self.client.subscribe(TOPIC_ATTACK_START)  # start an attack task

        # ``daemon=True`` replaces Thread.setDaemon(), deprecated since 3.10.
        t = threading.Thread(target=consume, args=(self.client, self.queue), daemon=True)
        t.start()

    def do_message(self, topic, payload):
        """Dispatch ``payload`` to the handler registered for ``topic``."""
        print("_++++++++++++++++++++++++{}".format(topic))
        self.router.get(topic, no_topic)(self.queue, topic, payload)
|