|
- import colorsys
- import os
- import time
- import numpy as np
-
-
- from keras import backend as K
- #from tensorflow.compat.v1.keras import backend as K
- import tensorflow as tf
# Switch TensorFlow to TF1-style graph mode at import time. The YOLO class in
# this module builds graph placeholders (K.placeholder), feeds
# K.learning_phase() and evaluates tensors with Session.run, all of which
# need graph mode rather than eager execution.
tf.compat.v1.disable_eager_execution()
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession
# Let TensorFlow grow GPU memory usage on demand instead of reserving
# (almost) all of it up front.
config = ConfigProto()
config.gpu_options.allow_growth = True
# An InteractiveSession installs itself as the default session when created,
# so this runs as a module-level side effect on import.
# NOTE(review): presumably this is the session later returned by
# K.get_session() in YOLO.__init__, so allow_growth applies to inference —
# confirm for the installed keras version.
session = InteractiveSession(config=config)
-
-
- from PIL import ImageDraw, ImageFont
-
- from nets.yolo import yolo_body
- from utils.utils import (cvtColor, get_anchors, get_classes, preprocess_input,
- resize_image)
- from utils.utils_bbox import DecodeBox
-
-
class YOLO(object):
    _defaults = {
        #--------------------------------------------------------------------------#
        # To predict with your own trained model you MUST change model_path and classes_path!
        # model_path points to a weight file under logs/, classes_path to a txt file under model_data/.
        #
        # After training, logs/ holds several weight files; pick one with a low validation loss.
        # A low validation loss does not guarantee a high mAP; it only means those weights
        # generalize well on the validation set.
        # If you get a shape mismatch, also check the model_path and classes_path used for training.
        #--------------------------------------------------------------------------#
        "model_path"        : 'model_data/ep060-loss7.310-val_loss7.303.h5',
        # Alternative: load a checkpoint straight from the training logs, e.g.
        # 'logs/ep060-loss7.310-val_loss7.303.h5'
        "classes_path"      : 'model_data/labels.txt',
        #---------------------------------------------------------------------#
        # anchors_path is the txt file holding the prior (anchor) boxes; usually left unchanged.
        # anchors_mask groups anchor indices (three groups of three) so the code can find
        # the matching anchors; usually left unchanged.
        #---------------------------------------------------------------------#
        "anchors_path"      : 'model_data/yolo_anchors.txt',
        "anchors_mask"      : [[6, 7, 8], [3, 4, 5], [0, 1, 2]],
        #---------------------------------------------------------------------#
        # Network input size; must be a multiple of 32.
        #---------------------------------------------------------------------#
        "input_shape"       : [416, 416],
        #---------------------------------------------------------------------#
        # Confidence threshold: only boxes scoring above it are kept.
        #---------------------------------------------------------------------#
        "confidence"        : 0.5,
        #---------------------------------------------------------------------#
        # IoU threshold used by non-maximum suppression.
        #---------------------------------------------------------------------#
        "nms_iou"           : 0.3,
        #---------------------------------------------------------------------#
        # Maximum number of boxes.
        #---------------------------------------------------------------------#
        "max_boxes"         : 100,
        #---------------------------------------------------------------------#
        # Whether to use letterbox_image (distortion-free resize with gray bars).
        # In repeated tests, disabling it and resizing directly worked better.
        #---------------------------------------------------------------------#
        "letterbox_image"   : False,
    }

    @classmethod
    def get_defaults(cls, n):
        """Return the default value of configuration option ``n``.

        Unknown names do not raise: the string
        "Unrecognized attribute name '<n>'" is returned instead.
        """
        if n in cls._defaults:
            return cls._defaults[n]
        else:
            return "Unrecognized attribute name '" + n + "'"

    def __init__(self, **kwargs):
        """Create the detector.

        Every key of ``_defaults`` becomes an instance attribute; keyword
        arguments override those defaults (and may add new attributes).
        Reads the class and anchor files, assigns one colour per class,
        then builds the network and its post-processing graph via generate().
        """
        self.__dict__.update(self._defaults)
        for name, value in kwargs.items():
            setattr(self, name, value)

        # Class names / anchors and their counts.
        self.class_names, self.num_classes = get_classes(self.classes_path)
        self.anchors, self.num_anchors = get_anchors(self.anchors_path)

        # One box colour per class: hues evenly spaced on the HSV wheel,
        # converted to 0-255 RGB tuples.
        hsv_tuples = [(x / self.num_classes, 1., 1.) for x in range(self.num_classes)]
        self.colors = [
            tuple(int(channel * 255) for channel in colorsys.hsv_to_rgb(*hsv))
            for hsv in hsv_tuples
        ]

        # Fed at run time with the original image's (height, width); presumably
        # used by DecodeBox to map boxes back onto the original image.
        self.input_image_shape = K.placeholder(shape=(2, ))
        self.sess = K.get_session()
        self.boxes, self.scores, self.classes = self.generate()

    def generate(self):
        """Build the YOLO network, load its weights and attach post-processing.

        Returns:
            The (boxes, scores, classes) tensors produced by DecodeBox; they are
            evaluated later with ``self.sess.run``.

        Raises:
            AssertionError: if ``model_path`` does not end with '.h5'.
        """
        model_path = os.path.expanduser(self.model_path)
        assert model_path.endswith('.h5'), 'Keras model or weights must be a .h5 file.'

        self.yolo_model = yolo_body([None, None, 3], self.anchors_mask, self.num_classes)
        # Load from the expanded path so that '~' in model_path is honoured.
        self.yolo_model.load_weights(model_path)
        print('{} model, anchors, and classes loaded.'.format(model_path))
        #---------------------------------------------------------#
        # Post-process the raw predictions: decoding, confidence
        # thresholding and non-maximum suppression.
        #---------------------------------------------------------#
        boxes, scores, classes = DecodeBox(
            self.yolo_model.output,
            self.anchors,
            self.num_classes,
            self.input_image_shape,
            self.input_shape,
            anchor_mask     = self.anchors_mask,
            max_boxes       = self.max_boxes,
            confidence      = self.confidence,
            nms_iou         = self.nms_iou,
            letterbox_image = self.letterbox_image
        )
        return boxes, scores, classes

    def _preprocess(self, image):
        """Convert ``image`` to RGB and build the network input batch.

        Returns (rgb_image, image_data): image_data is the resized image
        normalized by preprocess_input, with a leading batch axis of size 1.
        """
        # Convert to RGB so grayscale (and other) images do not fail at
        # prediction time; only RGB input is supported.
        image = cvtColor(image)
        # Letterbox (gray bars, no distortion) or plain resize, per config.
        image_data = resize_image(image, (self.input_shape[1], self.input_shape[0]), self.letterbox_image)
        # Add the batch dimension and normalize.
        image_data = np.expand_dims(preprocess_input(np.array(image_data, dtype='float32')), 0)
        return image, image_data

    def _run_inference(self, image, image_data):
        """Run the graph on one preprocessed image.

        Returns [out_boxes, out_scores, out_classes] as produced by Session.run.
        """
        return self.sess.run(
            [self.boxes, self.scores, self.classes],
            feed_dict={
                self.yolo_model.input: image_data,
                self.input_image_shape: [image.size[1], image.size[0]],
                K.learning_phase(): 0,
            })

    @staticmethod
    def _clip_box(box, image_size):
        """Floor a (top, left, bottom, right) box to ints and clip it to the image.

        ``image_size`` is a PIL-style (width, height) pair. Returns a tuple of
        Python ints (top, left, bottom, right).
        """
        width, height = image_size
        top, left, bottom, right = (int(np.floor(v)) for v in box)
        return max(0, top), max(0, left), min(height, bottom), min(width, right)

    @staticmethod
    def _text_size(draw, text, font):
        """Return the (width, height) covered by ``text`` drawn at the origin.

        ImageDraw.textsize was removed in Pillow 10; when it is missing, the
        right/bottom of textbbox at (0, 0) is used so the label background
        drawn from the text origin still covers the glyphs.
        """
        if hasattr(draw, 'textsize'):
            return draw.textsize(text, font)
        _, _, right, bottom = draw.textbbox((0, 0), text, font=font)
        return right, bottom

    def detect_image(self, image, crop = False):
        """Detect objects in a PIL image and return it annotated.

        Args:
            image: PIL image; converted to RGB before prediction.
            crop: when True, each detection is also saved as
                img_crop/crop_<i>.png (the directory is created if needed).

        Returns:
            The RGB image with a coloured box and "<class> <score>" label drawn
            for every detection.
        """
        image, image_data = self._preprocess(image)
        out_boxes, out_scores, out_classes = self._run_inference(image, image_data)

        print('Found {} boxes for {}'.format(len(out_boxes), 'img'))
        #---------------------------------------------------------#
        # Font size and border thickness scale with the image size.
        #---------------------------------------------------------#
        font_size = max(1, int(np.floor(3e-2 * image.size[1] + 0.5)))
        font = ImageFont.truetype(font='model_data/simhei.ttf', size=font_size)
        thickness = int(max((image.size[0] + image.size[1]) // np.mean(self.input_shape), 1))

        #---------------------------------------------------------#
        # Optionally save each detected object as a cropped image.
        #---------------------------------------------------------#
        if crop and len(out_boxes) > 0:
            dir_save_path = "img_crop"
            os.makedirs(dir_save_path, exist_ok=True)
            for i, box in enumerate(out_boxes):
                top, left, bottom, right = self._clip_box(box, image.size)
                crop_image = image.crop([left, top, right, bottom])
                crop_image.save(os.path.join(dir_save_path, "crop_" + str(i) + ".png"), quality=95, subsampling=0)
                print("save crop_" + str(i) + ".png to " + dir_save_path)

        #---------------------------------------------------------#
        # Draw boxes and labels.
        #---------------------------------------------------------#
        draw = ImageDraw.Draw(image)
        for i, c in enumerate(out_classes):
            class_id = int(c)
            predicted_class = self.class_names[class_id]
            score = out_scores[i]
            color = self.colors[class_id]
            top, left, bottom, right = self._clip_box(out_boxes[i], image.size)

            label = '{} {:.2f}'.format(predicted_class, score)
            label_size = self._text_size(draw, label, font)
            print(label, top, left, bottom, right)

            # Put the label above the box when it fits, otherwise just inside
            # the box's top edge.
            if top - label_size[1] >= 0:
                text_origin = np.array([left, top - label_size[1]])
            else:
                text_origin = np.array([left, top + 1])

            # The border is `thickness` nested 1-px rectangles; stop once they
            # would invert (thin boxes), which newer Pillow rejects.
            for offset in range(thickness):
                if left + offset > right - offset or top + offset > bottom - offset:
                    break
                draw.rectangle([left + offset, top + offset, right - offset, bottom - offset], outline=color)
            draw.rectangle([tuple(text_origin), tuple(text_origin + label_size)], fill=color)
            draw.text(tuple(text_origin), label, fill=(0, 0, 0), font=font)
        del draw

        return image

    def get_FPS(self, image, test_interval):
        """Return the mean wall-clock seconds per inference over ``test_interval`` runs.

        Preprocessing happens once, outside the timed loop, and one untimed
        warm-up inference runs first. ``test_interval`` must be non-zero
        (ZeroDivisionError otherwise).
        """
        image, image_data = self._preprocess(image)
        # Untimed warm-up run.
        self._run_inference(image, image_data)

        t1 = time.perf_counter()
        for _ in range(test_interval):
            self._run_inference(image, image_data)
        t2 = time.perf_counter()
        tact_time = (t2 - t1) / test_interval
        return tact_time

    def get_map_txt(self, image_id, image, class_names, map_out_path):
        """Write one image's detections to <map_out_path>/detection-results/<image_id>.txt.

        Each line is "<class> <score> <left> <top> <right> <bottom>". The score
        is truncated to its first 6 characters of str(), and the coordinates
        are truncated to int. Detections whose class is not in ``class_names``
        are skipped. The detection-results directory must already exist.
        """
        image, image_data = self._preprocess(image)
        out_boxes, out_scores, out_classes = self._run_inference(image, image_data)

        # Open only after inference succeeded; `with` guarantees the file is
        # closed even if writing fails.
        with open(os.path.join(map_out_path, "detection-results/"+image_id+".txt"), "w") as f:
            for i, c in enumerate(out_classes):
                predicted_class = self.class_names[int(c)]
                if predicted_class not in class_names:
                    continue
                score = str(out_scores[i])
                top, left, bottom, right = out_boxes[i]
                f.write("%s %s %s %s %s %s\n" % (predicted_class, score[:6], str(int(left)), str(int(top)), str(int(right)), str(int(bottom))))
        return

    def close_session(self):
        """Close the TensorFlow session used for inference."""
        self.sess.close()
|