|
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- # @Created : 2021/05/8
- # @Author : Koala
- # @FileName: stall.py
-
- import os
- import cv2
- import numpy as np
- from typing import Optional
- from run import TensorRTRun
- from utils import IoU
-
-
def report_alert(t, x1, y1, x2, y2):
    """Print an alert line for box corners (x1, y1)-(x2, y2) seen at time ``t``.

    The five values are printed together as a single tuple.
    """
    payload = (t, x1, y1, x2, y2)
    print(f"There is a alert: {payload}")
-
-
class IllegalParking:
    """Flag stalls (out-of-store business) that linger inside configured alert regions.

    A TensorRT detector runs on a video once per ``1 / detect_fps`` seconds of
    video time (one "detection step"). Detections whose IoU with any alert box
    exceeds ``iou_threshold`` are kept and tracked over a sliding window of the
    most recent ``alter_length`` steps. A kept detection is reported when the
    number of distinct earlier steps in the window that contain a matching box
    (IoU >= ``same_threshold``), plus one for the current step, reaches
    ``alter_number``.
    """

    def __init__(self, onnx_path='./sources/stall_model.onnx',
                 trt_path='./sources/stall_model.trt',
                 alert_boxes_filename='./sources/stall_alert_boxes.txt',
                 max_batch_size=1,      # maximum inference batch size
                 iou_threshold=0.3,     # a detection's IoU with an alert box must exceed this
                 score_threshold=0.4,   # detector confidence threshold
                 detect_fps=25,         # detection steps per second of video
                 alter_time=20,         # length of the tracking window, in seconds
                 alter_ratio=0.5,       # fraction of window steps that must contain the stall
                 same_threshold=0.7,    # IoU at/above which two boxes are the same stall
                 verbose=True,          # whether to display the annotated video
                 ):
        self.boxes = np.zeros([0, 4])
        self.runner = None  # type: Optional[TensorRTRun]
        self.iou_threshold = iou_threshold
        self.score_threshold = score_threshold
        self.detect_fps = detect_fps  # detection steps per second of video
        assert 0 <= alter_ratio <= 1
        # Number of detection steps kept in the sliding window.
        self.alter_length = int(alter_time * detect_fps)
        # Distinct steps within the window in which the same stall must appear
        # before it is treated as a violation and reported.
        self.alter_number = int(self.alter_length * alter_ratio)
        self.same_threshold = same_threshold  # IoU threshold for "same stall" matching
        print(f'每秒检测: {self.detect_fps} 帧视频, 当IoU>={self.same_threshold}时, 认为是同一摊位')
        print(f'在连续的 {self.alter_length} 次检测中, 同一摊位至少检测到 {self.alter_number} 次则认为是违规并报警')

        self.read_alert_boxes(alert_boxes_filename)
        self.load_model(trt_path, onnx_path, max_batch_size, True)
        # Tracking history, one row per kept detection: (x1, y1, x2, y2, step).
        # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is the same dtype.
        self._old_detected = np.zeros([0, 5], dtype=int)
        # Rows each detection step appended to ``_old_detected``, oldest first.
        self._num_detected = []
        self.verbose = verbose

    def read_alert_boxes(self, filename):
        """Load the alert regions from ``filename`` into ``self.boxes``.

        Each non-blank line holds whitespace-separated integers. The first four
        are printed as the corners (x1, y1), (x2, y2). An empty file yields an
        empty ``(0, 4)`` array rather than a shapeless ``(0,)`` one, so later
        IoU and drawing code still sees a 2-D box array.

        Raises:
            AssertionError: if ``filename`` is not an existing file (only while
                assertions are enabled).
        """
        assert os.path.isfile(filename), f'警戒框文件"{filename}"并不存在'
        with open(filename, 'r') as f:
            rows = [list(map(int, line.split())) for line in f if line.strip()]
        self.boxes = np.array(rows) if rows else np.zeros([0, 4], dtype=int)
        print(f'从文件 {filename} 读取到的警戒框有:')
        for box in self.boxes:
            print(f'\t({box[0]}, {box[1]}), ({box[2]}, {box[3]})')

    def load_model(self, trt_path, onnx_path='', max_batch_size=1, fp16=True):
        """Create ``self.runner`` and prepare it for inference.

        If ``trt_path`` exists, the serialized engine is loaded from it.
        Otherwise the ONNX model at ``onnx_path`` is built into an engine with
        ``build_engine`` (using ``fp16`` and ``max_batch_size``).
        ``runner.prepare()`` is called in both cases.

        Raises:
            AssertionError: if neither ``trt_path`` nor ``onnx_path`` exists
                (only while assertions are enabled).
        """
        self.runner = TensorRTRun(trt_path)
        if os.path.isfile(trt_path):
            print(f'从"{trt_path}"加载trt模型')
            self.runner.load_engine()
        else:
            assert os.path.isfile(onnx_path), f'ONNX模型文件"{onnx_path}"并不存在'
            # Forward the caller's precision choice (previously hard-coded to True).
            self.runner.build_engine(onnx_path, fp16=fp16, max_batch_size=max_batch_size)
            # NOTE(review): this message assumes build_engine serializes the engine
            # to trt_path -- confirm in run.TensorRTRun.
            print(f'使用ONNX模型"{onnx_path}"构建trt模型, 并保存在"{trt_path}"')
        self.runner.prepare()
        print('已完成准备工作')

    def detect(self, img):
        """Run the detector on ``img`` and return the detections inside alert regions.

        ``self.runner(img)[0]`` is indexed as a 2-D array: columns 0-3 are box
        corners and column 4 is the confidence score. Which output the ``[0]``
        selects (presumably the first batch item) should be confirmed against
        run.TensorRTRun. Detections scoring below ``score_threshold`` are
        dropped. The rest are kept when their IoU with at least one alert box
        exceeds ``iou_threshold``. When ``verbose`` is set, every scored
        detection is drawn on ``img`` in place: thick if it is in an alert
        region, thin otherwise.

        Returns:
            An ``int`` array of shape ``(k, 4)``; ``(0, 4)`` when nothing qualifies.
        """
        detected = self.runner(img)[0]
        if len(detected) == 0:
            return np.zeros([0, 4], dtype=int)

        print('店外经营:')
        # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is equivalent.
        detected = detected[detected[:, 4] >= self.score_threshold, :4].astype(int)
        if len(detected) == 0:
            return detected

        if len(self.boxes) == 0:
            # No alert regions configured, so no detection can be inside one.
            is_in_alert = np.zeros(len(detected), dtype=bool)
        else:
            iou = IoU(detected, self.boxes)
            is_in_alert = np.any(iou > self.iou_threshold, axis=1)
        if self.verbose:
            # tolist() gives plain Python ints, which cv2 drawing functions expect.
            for (x1, y1, x2, y2), alert in zip(detected.tolist(), is_in_alert):
                cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 3 if alert else 1)
        return detected[is_in_alert, :]

    def _track(self, boxes, now_t, second):
        """Record one detection step and return the boxes that raise an alert.

        ``boxes`` is the ``(n, 4)`` output of ``detect`` for step ``now_t``,
        taken at video time ``second`` seconds. The step is appended to the
        sliding history. A box alerts (via ``report_alert``) when the number of
        distinct earlier steps containing a box with IoU >= ``same_threshold``
        to it, plus one for this step, reaches ``alter_number``.
        """
        n = len(boxes)
        # Slide the window: once it holds alter_length steps, drop the oldest
        # step's rows. The non-empty check avoids IndexError when alter_length is 0.
        if self._num_detected and len(self._num_detected) >= self.alter_length:
            self._old_detected = self._old_detected[self._num_detected[0]:, :]
            self._num_detected = self._num_detected[1:]
        # Tag every box with its step index so matches can be counted per step.
        boxes = np.concatenate([boxes, np.full([n, 1], now_t)], axis=1)

        if n > 0 and len(self._old_detected) > 0:
            is_same = IoU(boxes, self._old_detected, False) >= self.same_threshold
        else:
            # Nothing to compare against; skip IoU on empty input.
            is_same = np.zeros([n, 0], dtype=bool)

        alerts = []
        for i in range(n):
            hit_steps = np.unique(self._old_detected[is_same[i], 4]).size
            if hit_steps + 1 >= self.alter_number:
                report_alert(second, *boxes[i, :4])
                alerts.append(boxes[i, :4])

        self._num_detected.append(n)
        self._old_detected = np.concatenate([self._old_detected, boxes], axis=0)
        return alerts

    def _render(self, window, frame, second, last_alert):
        """Draw overlays on ``frame``, show it in ``window``, and handle key presses.

        Returns ``False`` when the user asks to stop (Esc, q or Q, including
        while paused) and ``True`` to keep playing. Space pauses, and space
        again resumes.
        """
        cv2.putText(frame, f'time: {second:.2f}s', (20, 50), 2, 2, (255, 255, 0))
        # Configured alert regions.
        for x1, y1, x2, y2 in self.boxes:
            cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 255), 3)
        # Stalls reported at the most recent detection step.
        for x1, y1, x2, y2 in last_alert:
            cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
        cv2.imshow(window, frame)

        quit_keys = (27, ord('q'), ord('Q'))
        key = cv2.waitKey(1)
        if key in quit_keys:
            return False
        if key == ord(' '):
            # Paused: keep redrawing until space resumes or a quit key stops.
            while True:
                cv2.imshow(window, frame)
                key = cv2.waitKey(100)
                if key in quit_keys:
                    return False
                if key == ord(' '):
                    break
        return True

    def __call__(self, video_path):
        """Process ``video_path`` until the video ends or the user quits.

        ``video_path`` is passed straight to ``cv2.VideoCapture``. The display
        window is named ``str(video_path)``, so non-string sources can be shown
        too. Tracking state is reset at the start of every call. The capture is
        always released, and the window closed, on exit.
        """
        self._old_detected = np.zeros([0, 5], dtype=int)
        self._num_detected = []

        window = str(video_path)
        if self.verbose:
            cv2.namedWindow(window)
        cap = cv2.VideoCapture(video_path)
        try:
            if not cap.isOpened():
                print(f'无法打开视频源"{video_path}"')
                return
            W = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
            H = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
            print(f'视频大小: WxH={W}x{H}')
            last_t = -1
            last_alert = []

            while True:
                ret, frame = cap.read()
                if not ret:
                    return
                second = cap.get(cv2.CAP_PROP_POS_MSEC) * 1e-3
                # Detect only when video time enters a new 1/detect_fps slot.
                now_t = int(second * self.detect_fps)
                if now_t != last_t:
                    last_alert = self._track(self.detect(frame), now_t, second)
                    last_t = now_t
                if self.verbose and not self._render(window, frame, second, last_alert):
                    return
        finally:
            cap.release()
            if self.verbose:
                try:
                    cv2.destroyWindow(window)
                except cv2.error:
                    pass  # best-effort cleanup: the window may already be gone
-
-
def main(
        video_path='2.mp4',                                      # video source
        onnx_path='./sources/stall_model.onnx',                  # ONNX model path
        trt_path='./sources/stall_model.trt',                    # serialized TensorRT engine path
        alert_boxes_filename='./sources/stall_alert_boxes.txt',  # file with alert-region coordinates
):
    """Build an ``IllegalParking`` detector from the given files and run it on ``video_path``."""
    detector = IllegalParking(
        onnx_path=onnx_path,
        trt_path=trt_path,
        alert_boxes_filename=alert_boxes_filename,
    )
    detector(video_path)
-
-
if __name__ == "__main__":
    # Script entry point: run with the default video and model paths.
    main()
|