|
- import tensorflow as tf
- from absl import app, flags, logging
- from absl.flags import FLAGS
- from core.yolov4 import YOLO, decode, filter_boxes
- import core.utils as utils
- from core.config import cfg
-
# Command-line configuration for the export (read through FLAGS by save_tf()).
flags.DEFINE_string('weights', './data/yolov4.weights', 'path to weights file')
flags.DEFINE_string('output', './checkpoints/yolov4-416', 'path to output')
flags.DEFINE_boolean('tiny', False, 'is yolo-tiny or not')
flags.DEFINE_integer('input_size', 416, 'define input size of export model')
flags.DEFINE_float('score_thres', 0.2, 'define score threshold')
# 'framework' and 'model' only have a fixed set of meaningful values, so they
# are enums: a typo such as --framework=tflight is rejected at flag-parse time
# instead of silently falling through to the non-TFLite export path in save_tf().
flags.DEFINE_enum('framework', 'tf', ['tf', 'trt', 'tflite'],
                  'define what framework do you want to convert (tf, trt, tflite)')
flags.DEFINE_enum('model', 'yolov4', ['yolov3', 'yolov4'], 'yolov3 or yolov4')
-
def save_tf():
    """Build the YOLO graph selected by FLAGS, load its weights and save it.

    Flags read: ``model``, ``tiny``, ``input_size``, ``framework``,
    ``score_thres``, ``weights`` and ``output``.

    Every feature map returned by ``YOLO`` is passed through ``decode``. The
    per-scale box tensors and class-probability tensors are then concatenated
    along axis 1.

    The model's output depends on the framework:

    * ``--framework=tflite``: the raw ``(pred_bbox, pred_prob)`` pair.
    * Otherwise: the boxes are filtered with ``filter_boxes`` at
      ``FLAGS.score_thres``, and the output is the boxes and confidences
      concatenated on the last axis.

    The weights in ``FLAGS.weights`` are loaded into the Keras model, which is
    then saved to ``FLAGS.output``.
    """
    STRIDES, ANCHORS, NUM_CLASS, XYSCALE = utils.load_config(FLAGS)

    input_layer = tf.keras.layers.Input([FLAGS.input_size, FLAGS.input_size, 3])
    feature_maps = YOLO(input_layer, NUM_CLASS, FLAGS.model, FLAGS.tiny)

    # Divisor applied to input_size to get each feature map's grid size, in the
    # order YOLO returns the maps. Maps beyond the end of the table reuse the
    # last (coarsest) divisor, mirroring the original if/elif/else dispatch.
    grid_divisors = (16, 32) if FLAGS.tiny else (8, 16, 32)

    bbox_tensors = []
    prob_tensors = []
    for i, fm in enumerate(feature_maps):
        divisor = grid_divisors[min(i, len(grid_divisors) - 1)]
        output_tensors = decode(fm, FLAGS.input_size // divisor, NUM_CLASS,
                                STRIDES, ANCHORS, i, XYSCALE, FLAGS.framework)
        bbox_tensors.append(output_tensors[0])
        prob_tensors.append(output_tensors[1])

    pred_bbox = tf.concat(bbox_tensors, axis=1)
    pred_prob = tf.concat(prob_tensors, axis=1)

    if FLAGS.framework == 'tflite':
        # TFLite export keeps the raw tensors; score filtering is left to
        # whatever consumes the converted model.
        pred = (pred_bbox, pred_prob)
    else:
        boxes, pred_conf = filter_boxes(
            pred_bbox, pred_prob,
            score_threshold=FLAGS.score_thres,
            input_shape=tf.constant([FLAGS.input_size, FLAGS.input_size]))
        pred = tf.concat([boxes, pred_conf], axis=-1)

    model = tf.keras.Model(input_layer, pred)
    utils.load_weights(model, FLAGS.weights, FLAGS.model, FLAGS.tiny)
    model.summary()
    model.save(FLAGS.output)
-
def main(_argv):
    """Entry point handed to absl's ``app.run``; performs the export via save_tf()."""
    del _argv  # Extra positional command-line arguments are not used.
    save_tf()
-
if __name__ == '__main__':
    # absl's app.run() finishes by raising SystemExit, both on normal completion
    # and on flag/usage errors. Swallowing it keeps interactive runs (e.g.
    # notebooks) quiet. Only a clean exit (code None or 0) is swallowed, though;
    # non-zero codes are re-raised so shells and CI can detect a failed export.
    try:
        app.run(main)
    except SystemExit as exit_request:
        if exit_request.code not in (None, 0):
            raise
|