|
- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """FOTS eval."""
- import os
- import argparse
- import datetime
- import time
- import sys
- import ast
- import re
- from collections import defaultdict
- from collections import namedtuple
- import Polygon as plg
-
- import numpy as np
-
- from mindspore import Tensor
- from mindspore.context import ParallelMode
- from mindspore import context
- from mindspore.train.serialization import load_checkpoint, load_param_into_net
- import mindspore as ms
- import mindspore.ops as ops
-
- from src.fots import FOTS
- from src.logger import get_logger
- from src.config import configFOTS
- from src.fots_dataset import create_fots_dataset
- from src.modules.parse_polys import parse_polys
-
# Command-line interface: parsed at import time so `args` is available module-wide.
parser = argparse.ArgumentParser('mindspore fots testing')

# device related
parser.add_argument('--device_target', type=str, default='Ascend',
                    help='device where the code will be implemented. (Default: Ascend)')

# dataset related
parser.add_argument('--data_dir', type=str, default='/home/lzh/2021-9-25/dataset/ICDAR2015/task4_1/', help='test data dir')
parser.add_argument('--per_batch_size', default=1, type=int, help='batch size for per gpu')

# network related
parser.add_argument('--pretrained', default='/home/lzh/2021-9-25/preTrain_ckpt/fots_epoch582.ckpt', type=str, help='model_path, local pretrained model to load')

# logging related
parser.add_argument('--log_path', type=str, default='/home/lzh/2021-9-25/outputs/', help='checkpoint save location')

# detect_related
parser.add_argument('--result', type=str, default='/home/lzh/2021-9-25/res/', help='detect result')
parser.add_argument('--nms_thresh', type=float, default=0.6, help='threshold for NMS')
parser.add_argument('--testing_shape', type=str, default='', help='shape for test ')
parser.add_argument('--ignore_threshold', type=float, default=0.001, help='threshold to throw low quality boxes')
parser.add_argument('--multi_label', type=ast.literal_eval, default=True, help='whether to use multi label')
parser.add_argument('--multi_label_thresh', type=float, default=0.1, help='threshhold to throw low quality boxes')
parser.add_argument('--is_modelArts', type=int, default=0,
                    help='Trainning in modelArts or not, 1 for yes, 0 for no. Default: 0')

# parse_known_args: unknown flags are ignored instead of raising.
args, _ = parser.parse_known_args()
args.rank = 0

if args.is_modelArts:  # TODO: to be revised (translated from "待改")
    # On ModelArts, datasets and the checkpoint are mirrored from OBS into
    # the node-local /cache directory before use.
    # NOTE(review): these paths ('val2017', 'instances_val2017.json') look
    # copied from a COCO setup, not ICDAR -- confirm before running on ModelArts.
    args.data_root = os.path.join(args.data_dir, 'val2017')
    args.ann_file = os.path.join(args.data_dir, 'annotations')
    import moxing as mox

    local_data_url = os.path.join('/cache/data', str(args.rank))
    local_annFile = os.path.join('/cache/data', str(args.rank))
    local_pretrained = os.path.join('/cache/data', str(args.rank))

    # Split the checkpoint path into directory + file name so the whole
    # directory can be copied in parallel, then the file re-joined locally.
    temp_str = args.pretrained.split('/')[-1]
    args.pretrained = args.pretrained[0:args.pretrained.rfind('/')]

    mox.file.copy_parallel(args.data_root, local_data_url)
    args.data_root = local_data_url

    mox.file.copy_parallel(args.ann_file, local_annFile)
    args.ann_file = os.path.join(local_data_url, 'instances_val2017.json')

    mox.file.copy_parallel(args.pretrained, local_pretrained)
    args.pretrained = os.path.join(local_data_url, temp_str)
else:
    # Local run: standard ICDAR-2015 test layout under data_dir.
    args.data_root = os.path.join(args.data_dir, 'ch4_test_images')
    args.ann_file = os.path.join(args.data_dir, 'Challenge4_Test_Task1_GT')
-
-
class Redirct:
    """Minimal file-like sink that accumulates everything written to it.

    Used to capture output normally sent to stdout; `flush` discards the
    buffered text.
    """

    def __init__(self):
        # Accumulated text written so far.
        self.content = ""

    def write(self, content):
        """Append `content` to the internal buffer."""
        self.content = self.content + content

    def flush(self):
        """Discard the buffered text."""
        self.content = ""
-
-
class DetectionEngine:
    """ICDAR-2015 style text-detection evaluator.

    Reads ground-truth and detection polygon text files and computes the
    standard detection metrics: precision, recall, hmean (F1) and AP.
    Polygon geometry is handled by the third-party `Polygon` package (plg).
    """

    def __init__(self):
        # NOTE(review): these attributes mirror the training-side engine and
        # are not used by the evaluation methods below; kept for compatibility.
        self.labels = ['Yes', 'No']
        self.num_classes = len(self.labels)
        self.results = {}
        self.file_path = ''
        self._img_ids = list()
        self.det_boxes = []
        self.coco_catIds = [1, 2]

    def polygon_from_points(self, points):
        """
        Returns a Polygon object to use with the Polygon2 class from a list of 8 points: x1,y1,x2,y2,x3,y3,x4,y4
        """
        resBoxes = np.empty([1, 8], dtype='int32')
        # x coordinates fill columns 0-3 and y coordinates columns 4-7 so that
        # reshape([2, 4]).T yields four (x, y) vertices in order.
        resBoxes[0, 0] = int(points[0])
        resBoxes[0, 4] = int(points[1])
        resBoxes[0, 1] = int(points[2])
        resBoxes[0, 5] = int(points[3])
        resBoxes[0, 2] = int(points[4])
        resBoxes[0, 6] = int(points[5])
        resBoxes[0, 3] = int(points[6])
        resBoxes[0, 7] = int(points[7])
        pointMat = resBoxes[0].reshape([2, 4]).T
        return plg.Polygon(pointMat)

    def rectangle_to_polygon(self, rect):
        """Return a Polygon for an axis-aligned rectangle (needs .xmin/.ymin/.xmax/.ymax)."""
        resBoxes = np.empty([1, 8], dtype='int32')
        # Corners in order: (xmin,ymax), (xmin,ymin), (xmax,ymin), (xmax,ymax).
        resBoxes[0, 0] = int(rect.xmin)
        resBoxes[0, 4] = int(rect.ymax)
        resBoxes[0, 1] = int(rect.xmin)
        resBoxes[0, 5] = int(rect.ymin)
        resBoxes[0, 2] = int(rect.xmax)
        resBoxes[0, 6] = int(rect.ymin)
        resBoxes[0, 3] = int(rect.xmax)
        resBoxes[0, 7] = int(rect.ymax)
        pointMat = resBoxes[0].reshape([2, 4]).T
        return plg.Polygon(pointMat)

    def rectangle_to_points(self, rect):
        """Return the rectangle's 4 corners as a flat [x, y] * 4 int list."""
        points = [int(rect.xmin), int(rect.ymax), int(rect.xmax), int(rect.ymax),
                  int(rect.xmax), int(rect.ymin), int(rect.xmin), int(rect.ymin)]
        return points

    def get_union(self, pD, pG):
        """Area of the union of two polygons (inclusion-exclusion)."""
        areaA = pD.area()
        areaB = pG.area()
        return areaA + areaB - self.get_intersection(pD, pG)

    def get_intersection_over_union(self, pD, pG):
        """IoU of two polygons; 0 when the union area is zero (degenerate)."""
        try:
            return self.get_intersection(pD, pG) / self.get_union(pD, pG)
        except ZeroDivisionError:
            return 0

    def get_intersection(self, pD, pG):
        """Area of the intersection of two polygons; 0 when disjoint."""
        pInt = pD & pG
        # An empty Polygon has no contours (len == 0) and no area.
        if len(pInt) == 0:
            return 0
        return pInt.area()

    def compute_ap(self, confList, matchList, numGtCare):
        """Average precision over detections sorted by descending confidence.

        confList: per-detection confidences; matchList: per-detection
        booleans (matched a care-GT); numGtCare: number of care GT boxes.
        Returns 0 when there are no detections or no care GT.
        """
        correct = 0
        AP = 0
        if len(confList) > 0:
            confList = np.array(confList)
            matchList = np.array(matchList)
            # Sort detections by descending confidence before accumulating.
            sorted_ind = np.argsort(-confList)
            matchList = matchList[sorted_ind]
            for n, match in enumerate(matchList):
                if match:
                    correct += 1
                    AP += float(correct) / (n + 1)
            if numGtCare > 0:
                AP /= numGtCare
        return AP

    def default_evaluation_params(self):
        """
        default_evaluation_params: Default parameters to use for the validation and evaluation.
        """
        return {
            'IOU_CONSTRAINT': 0.5,
            'AREA_PRECISION_CONSTRAINT': 0.5,
            'LTRB': False,  # LTRB:2points(left,top,right,bottom) or 4 points(x1,y1,x2,y2,x3,y3,x4,y4)
            'CRLF': False,  # Lines are delimited by Windows CRLF format
            'CONFIDENCES': False,  # Detections must include confidence value. AP will be calculated
            'PER_SAMPLE_RESULTS': True  # Generate per sample results and produce data for visualization
        }

    def decode_utf8(self, raw):
        """
        Returns a Unicode object on success, or None on failure
        """
        try:
            return raw.decode('utf-8-sig', errors='replace')
        except Exception:  # best-effort decode: any failure maps to None
            return None

    def validate_clockwise_points(self, points):
        """
        Validates that the points are in clockwise order.
        """
        # Shoelace-style signed-area sign test; positive sum means the
        # polygon is counter-clockwise in image coordinates (y down).
        edge = []
        for i in range(len(points) // 2):
            edge.append((int(points[(i + 1) * 2 % len(points)]) - int(points[i * 2]))
                        * (int(points[((i + 1) * 2 + 1) % len(points)]) + int(points[i * 2 + 1])))
        if sum(edge) > 0:
            raise Exception("Points are not clockwise. The coordinates of bounding points have to be given in clockwise order. Regarding the correct interpretation of 'clockwise' remember that the image coordinate system used is the standard one, with the image origin at the upper left, the X axis extending to the right and Y axis extending downwards.")

    def get_tl_line_values(self, line, LTRB=True, withTranscription=False, withConfidence=False, imWidth=0, imHeight=0):
        """
        Validate the format of the line. If the line is not valid an exception will be raised.
        If maxWidth and maxHeight are specified, all points must be inside the image bounds.
        Possible values are:
        LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription]
        LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription]
        Returns values from a textline. Points, [Confidences], [Transcriptions]
        """
        confidence = 0.0
        transcription = ""
        points = []

        if LTRB:
            numPoints = 4

            if withTranscription and withConfidence:
                m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$', line)
                # BUGFIX: removed a dead duplicate re.match that ran just
                # before the raise in the original code.
                if m is None:
                    raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence,transcription")
            elif withConfidence:
                m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$', line)
                if m is None:
                    raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence")
            elif withTranscription:
                m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,(.*)$', line)
                if m is None:
                    raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,transcription")
            else:
                m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,?\s*$', line)
                if m is None:
                    raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax")

            xmin = int(m.group(1))
            ymin = int(m.group(2))
            xmax = int(m.group(3))
            ymax = int(m.group(4))
            if xmax < xmin:
                raise Exception("Xmax value (%s) not valid (Xmax < Xmin)." % (xmax))
            if ymax < ymin:
                raise Exception("Ymax value (%s) not valid (Ymax < Ymin)." % (ymax))

            points = [float(m.group(i)) for i in range(1, numPoints + 1)]

            if imWidth > 0 and imHeight > 0:
                # NOTE(review): validate_point_inside_bounds is not defined in
                # this file; callers here always pass imWidth=imHeight=0 so
                # this branch never runs -- define it before enabling bounds checks.
                validate_point_inside_bounds(xmin, ymin, imWidth, imHeight)
                validate_point_inside_bounds(xmax, ymax, imWidth, imHeight)
        else:
            numPoints = 8

            if withTranscription and withConfidence:
                m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$', line)
                if m is None:
                    raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence,transcription")
            elif withConfidence:
                m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$', line)
                if m is None:
                    raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence")
            elif withTranscription:
                m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,(.*)$', line)
                if m is None:
                    raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,transcription")
            else:
                m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*$', line)
                if m is None:
                    raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4")

            points = [float(m.group(i)) for i in range(1, numPoints + 1)]

            self.validate_clockwise_points(points)

            if imWidth > 0 and imHeight > 0:
                # NOTE(review): see the LTRB branch -- helper undefined, branch unreachable here.
                validate_point_inside_bounds(points[0], points[1], imWidth, imHeight)
                validate_point_inside_bounds(points[2], points[3], imWidth, imHeight)
                validate_point_inside_bounds(points[4], points[5], imWidth, imHeight)
                validate_point_inside_bounds(points[6], points[7], imWidth, imHeight)

        if withConfidence:
            try:
                confidence = float(m.group(numPoints + 1))
            except ValueError:
                raise Exception("Confidence value must be a float")

        if withTranscription:
            posTranscription = numPoints + (2 if withConfidence else 1)
            transcription = m.group(posTranscription)
            m2 = re.match(r'^\s*\"(.*)\"\s*$', transcription)
            if m2 is not None:
                # Transcription with double quotes: extract the value and unescape \\ and \"
                transcription = m2.group(1).replace("\\\\", "\\").replace("\\\"", "\"")

        return points, confidence, transcription

    def get_tl_line_values_from_file_contents(self, content, CRLF=True, LTRB=True, withTranscription=False, withConfidence=False, imWidth=0, imHeight=0, sort_by_confidences=True):
        """
        Returns all points, confidences and transcriptions of a file in lists. Valid line formats:
        xmin,ymin,xmax,ymax,[confidence],[transcription]
        x1,y1,x2,y2,x3,y3,x4,y4,[confidence],[transcription]
        """
        pointsList = []
        transcriptionsList = []
        confidencesList = []

        lines = content.split("\r\n" if CRLF else "\n")
        for line in lines:
            line = line.replace("\r", "").replace("\n", "")
            if line != "":
                points, confidence, transcription = self.get_tl_line_values(
                    line, LTRB, withTranscription, withConfidence, imWidth, imHeight)
                pointsList.append(points)
                transcriptionsList.append(transcription)
                confidencesList.append(confidence)

        if withConfidence and len(confidencesList) > 0 and sort_by_confidences:
            # Reorder all three lists by descending confidence.
            sorted_ind = np.argsort(-np.array(confidencesList))
            confidencesList = [confidencesList[i] for i in sorted_ind]
            pointsList = [pointsList[i] for i in sorted_ind]
            transcriptionsList = [transcriptionsList[i] for i in sorted_ind]

        return pointsList, confidencesList, transcriptionsList

    def read_det_gt_txt2dict(self, gtFilePath, detFilePath):
        """Read all detection and GT txt files into dicts keyed by 1-based index.

        File order is derived from the number embedded in the file name:
        detections via name[8:-4], GT via name[7:-4] -- presumably
        'res_img_<n>.txt' and 'gt_img_<n>.txt'; TODO confirm naming scheme.
        Callers pass directory paths with a trailing '/', hence plain concatenation.
        """
        # det
        det = {}
        detTxtLists = os.listdir(detFilePath)
        detTxtLists.sort(key=lambda x: int(x[8:-4]))
        for i, txt in enumerate(detTxtLists, start=1):
            # BUGFIX: files were opened without being closed; use `with`.
            with open(detFilePath + txt, 'r') as f:
                det[i] = f.read()

        # gt
        gt = {}
        gtTxtLists = os.listdir(gtFilePath)
        gtTxtLists.sort(key=lambda x: int(x[7:-4]))
        for i, txt in enumerate(gtTxtLists, start=1):
            with open(gtFilePath + txt, 'r') as f:
                gt[i] = f.read()

        return det, gt

    def evaluate_method(self, gtFilePath, detFilePath):
        """Run the ICDAR evaluation over every GT/detection file pair.

        Returns a dict with global 'method' metrics (precision, recall,
        hmean, AP) and, when PER_SAMPLE_RESULTS is set, per-image metrics.
        """
        perSampleMetrics = {}

        evaluationParams = self.default_evaluation_params()

        matchedSum = 0

        Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax')

        det, gt = self.read_det_gt_txt2dict(gtFilePath, detFilePath)

        numGlobalCareGt = 0
        numGlobalCareDet = 0

        arrGlobalConfidences = []
        arrGlobalMatches = []

        for resFile in gt:
            gtFile = gt[resFile]
            recall = 0
            precision = 0
            hmean = 0

            detMatched = 0

            iouMat = np.empty([1, 1])

            gtPols = []
            detPols = []

            gtPolPoints = []
            detPolPoints = []

            # Array of Ground Truth Polygons' keys marked as don't care
            gtDontCarePolsNum = []
            # Array of Detected Polygons' matched with a don't care GT
            detDontCarePolsNum = []

            pairs = []
            detMatchedNums = []

            arrSampleConfidences = []
            arrSampleMatch = []
            sampleAP = 0

            evaluationLog = ""

            pointsList, _, transcriptionsList = self.get_tl_line_values_from_file_contents(
                gtFile, evaluationParams['CRLF'], evaluationParams['LTRB'], True, False)
            for n in range(len(pointsList)):
                points = pointsList[n]
                transcription = transcriptionsList[n]
                # "###" marks unreadable text -> don't-care GT box
                dontCare = transcription == "###"
                if evaluationParams['LTRB']:
                    gtRect = Rectangle(*points)
                    gtPol = self.rectangle_to_polygon(gtRect)
                else:
                    gtPol = self.polygon_from_points(points)
                gtPols.append(gtPol)
                gtPolPoints.append(points)
                if dontCare:
                    gtDontCarePolsNum.append(len(gtPols) - 1)

            evaluationLog += "GT polygons: " + str(len(gtPols)) + (
                " (" + str(len(gtDontCarePolsNum)) + " don't care)\n" if len(gtDontCarePolsNum) > 0 else "\n")

            if resFile in det:

                detFile = det[resFile]

                pointsList, confidencesList, _ = self.get_tl_line_values_from_file_contents(
                    detFile, evaluationParams['CRLF'], evaluationParams['LTRB'], False,
                    evaluationParams['CONFIDENCES'])
                for n in range(len(pointsList)):
                    points = pointsList[n]

                    if evaluationParams['LTRB']:
                        detRect = Rectangle(*points)
                        detPol = self.rectangle_to_polygon(detRect)
                    else:
                        detPol = self.polygon_from_points(points)
                    detPols.append(detPol)
                    detPolPoints.append(points)
                    if len(gtDontCarePolsNum) > 0:
                        # A detection overlapping a don't-care GT above the area
                        # threshold is itself excluded from the metrics.
                        for dontCarePol in gtDontCarePolsNum:
                            dontCarePol = gtPols[dontCarePol]
                            intersected_area = self.get_intersection(dontCarePol, detPol)
                            pdDimensions = detPol.area()
                            precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions
                            if precision > evaluationParams['AREA_PRECISION_CONSTRAINT']:
                                detDontCarePolsNum.append(len(detPols) - 1)
                                break

                evaluationLog += "DET polygons: " + str(len(detPols)) + (
                    " (" + str(len(detDontCarePolsNum)) + " don't care)\n" if len(detDontCarePolsNum) > 0 else "\n")

                if len(gtPols) > 0 and len(detPols) > 0:
                    # Calculate IoU matrix: rows = GT polygons, cols = detections.
                    outputShape = [len(gtPols), len(detPols)]
                    iouMat = np.empty(outputShape)
                    gtRectMat = np.zeros(len(gtPols), np.int8)
                    detRectMat = np.zeros(len(detPols), np.int8)
                    for gtNum in range(len(gtPols)):
                        for detNum in range(len(detPols)):
                            pG = gtPols[gtNum]
                            pD = detPols[detNum]
                            iouMat[gtNum, detNum] = self.get_intersection_over_union(pD, pG)

                    # Greedy one-to-one matching of care GT/detections over the IoU threshold.
                    for gtNum in range(len(gtPols)):
                        for detNum in range(len(detPols)):
                            if gtRectMat[gtNum] == 0 and detRectMat[detNum] == 0 \
                                    and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum:
                                if iouMat[gtNum, detNum] > evaluationParams['IOU_CONSTRAINT']:
                                    gtRectMat[gtNum] = 1
                                    detRectMat[detNum] = 1
                                    detMatched += 1
                                    pairs.append({'gt': gtNum, 'det': detNum})
                                    detMatchedNums.append(detNum)
                                    evaluationLog += "Match GT #" + str(gtNum) + " with Det #" + str(detNum) + "\n"

                if evaluationParams['CONFIDENCES']:
                    for detNum in range(len(detPols)):
                        if detNum not in detDontCarePolsNum:
                            # we exclude the don't care detections
                            match = detNum in detMatchedNums

                            arrSampleConfidences.append(confidencesList[detNum])
                            arrSampleMatch.append(match)

                            arrGlobalConfidences.append(confidencesList[detNum])
                            arrGlobalMatches.append(match)

            numGtCare = (len(gtPols) - len(gtDontCarePolsNum))
            numDetCare = (len(detPols) - len(detDontCarePolsNum))
            if numGtCare == 0:
                # No care GT: recall is trivially 1; precision penalizes any detections.
                recall = float(1)
                precision = float(0) if numDetCare > 0 else float(1)
                sampleAP = precision
            else:
                recall = float(detMatched) / numGtCare
                precision = 0 if numDetCare == 0 else float(detMatched) / numDetCare
                if evaluationParams['CONFIDENCES'] and evaluationParams['PER_SAMPLE_RESULTS']:
                    # BUGFIX: was an unqualified `compute_ap(...)` -> NameError
                    # whenever CONFIDENCES and PER_SAMPLE_RESULTS were both set.
                    sampleAP = self.compute_ap(arrSampleConfidences, arrSampleMatch, numGtCare)

            hmean = 0 if (precision + recall) == 0 else 2.0 * precision * recall / (precision + recall)

            matchedSum += detMatched
            numGlobalCareGt += numGtCare
            numGlobalCareDet += numDetCare

            if evaluationParams['PER_SAMPLE_RESULTS']:
                perSampleMetrics[resFile] = {
                    'precision': precision,
                    'recall': recall,
                    'hmean': hmean,
                    'pairs': pairs,
                    # Large IoU matrices are dropped to keep the result small.
                    'iouMat': [] if len(detPols) > 100 else iouMat.tolist(),
                    'AP': sampleAP,
                    'gtPolPoints': gtPolPoints,
                    'detPolPoints': detPolPoints,
                    'gtDontCare': gtDontCarePolsNum,
                    'detDontCare': detDontCarePolsNum,
                    'evaluationParams': evaluationParams,
                    'evaluationLog': evaluationLog
                }

        # Compute global (micro-averaged) metrics and AP.
        AP = 0
        if evaluationParams['CONFIDENCES']:
            AP = self.compute_ap(arrGlobalConfidences, arrGlobalMatches, numGlobalCareGt)

        methodRecall = 0 if numGlobalCareGt == 0 else float(matchedSum) / numGlobalCareGt
        methodPrecision = 0 if numGlobalCareDet == 0 else float(matchedSum) / numGlobalCareDet
        methodHmean = 0 if methodRecall + methodPrecision == 0 else \
            2 * methodRecall * methodPrecision / (methodRecall + methodPrecision)

        methodMetrics = {'precision': methodPrecision, 'recall': methodRecall, 'hmean': methodHmean, 'AP': AP}

        resDict = {'calculated': True, 'Message': '', 'method': methodMetrics, 'per_sample': perSampleMetrics}

        return resDict
-
-
-
def convert_testing_shape(args_testing_shape):
    """Convert a scalar testing shape into a square [height, width] list."""
    side = int(args_testing_shape)
    return [side, side]
-
-
if __name__ == "__main__":
    start_time = time.time()
    # BUGFIX: `int(os.getenv('DEVICE_ID')) if os.getenv('DEVICE_ID') else 2`
    # crashed when DEVICE_ID was set to an empty string; `or` covers both
    # unset and empty, defaulting to card 2.
    device_id = int(os.getenv('DEVICE_ID') or 2)
    # PYNATIVE_MODE vs GRAPH_MODE: eval runs in pynative mode.
    context.set_context(mode=context.PYNATIVE_MODE, device_target=args.device_target, device_id=device_id)

    # Logger writes into a timestamped sub-directory of log_path.
    args.outputs_dir = os.path.join(args.log_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    rank_id = int(os.getenv('DEVICE_ID', '0'))
    args.logger = get_logger(args.outputs_dir, rank_id)

    # Single-device evaluation: stand-alone parallel context.
    context.reset_auto_parallel_context()
    parallel_mode = ParallelMode.STAND_ALONE
    context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=1)

    args.logger.info('Creating Network....')
    network = FOTS()
    print(args)
    args.logger.info(args.pretrained)
    if os.path.isfile(args.pretrained):
        param_dict = load_checkpoint(args.pretrained)
        # Remap checkpoint keys from the training namespace
        # ('fots_network.resnet.*') onto this eval network's layout.
        param_dict_new = {}
        for key, values in param_dict.items():
            if key.startswith('moments.'):
                # optimizer state -- not needed for inference
                continue
            # stem conv + bn: fots_network.resnet.conv1/bn1 -> conv1.0/conv1.1
            elif key.startswith('fots_network.resnet.conv1.weight'):
                param_dict_new['conv1.0.weight'] = values
            elif key.startswith('fots_network.resnet.bn1.gamma'):
                param_dict_new['conv1.1.gamma'] = values
            elif key.startswith('fots_network.resnet.bn1.beta'):
                param_dict_new['conv1.1.beta'] = values
            elif key.startswith('fots_network.resnet.bn1.moving_mean'):
                param_dict_new['conv1.1.moving_mean'] = values
            elif key.startswith('fots_network.resnet.bn1.moving_variance'):
                param_dict_new['conv1.1.moving_variance'] = values
            # encoder layers: strip 'fots_network.resnet.' (20 chars), rename layer->encoder
            elif key.startswith('fots_network.resnet.'):
                param_dict_new[key[20:].replace('layer', 'encoder')] = values
            # other layers: strip 'fots_network.' (13 chars)
            elif key.startswith('fots_network.'):
                param_dict_new[key[13:]] = values
            else:
                param_dict_new[key] = values
        load_param_into_net(network, param_dict_new)
        args.logger.info('load_model {} success'.format(args.pretrained))
    else:
        args.logger.info('{} not exists or not a pre-trained file'.format(args.pretrained))
        # BUGFIX: the original `assert FileNotFoundError(...)` always passed
        # (an exception instance is truthy); fail for real instead.
        raise FileNotFoundError('{} not exists or not a pre-trained file'.format(args.pretrained))

    data_root = args.data_root
    ann_file = args.ann_file

    config = configFOTS()
    if args.testing_shape:
        config.test_img_shape = convert_testing_shape(args.testing_shape)

    ds, ds_origin = create_fots_dataset(data_root, ann_file, is_training=False, batch_size=args.per_batch_size,
                                        max_epoch=1, device_num=1, rank=rank_id, shuffle=False, config=config)
    data_size = len(ds_origin)

    args.logger.info('total {} images to eval'.format(data_size))

    network.set_train(False)

    # init detection engine
    detection = DetectionEngine()

    args.logger.info('Start inference....')
    for image_index, data in enumerate(ds.create_dict_iterator(output_numpy=True, num_epochs=1)):

        img_idx = data["img_idx"].astype(np.int32)
        prefix = ds_origin.image_prefix[img_idx[0]]
        # NHWC -> NCHW for the network input.
        image = Tensor.from_numpy(np.ascontiguousarray(np.transpose(data["image"], axes=(0, 3, 1, 2))))

        # Per-axis scale used to map predicted boxes back to original image size.
        scale_x = data["scale"][0][0].item()
        scale_y = data["scale"][0][1].item()

        confidence, distances, angle = network(image)

        sigmoid = ops.Sigmoid()
        squeeze = ops.Squeeze()

        confidence = squeeze(sigmoid(confidence)).asnumpy()
        distances = squeeze(distances).asnumpy()
        angle = squeeze(angle).asnumpy()

        # Decode score/geometry maps into quadrilaterals (thresholds 0.95 / 0.3).
        polys = parse_polys(confidence, distances, angle, 0.95, 0.3)
        # One result file per image, ICDAR submission format: 8 ints per line.
        with open('{}'.format(os.path.join(args.result, 'detect/', 'res_{}.txt'.format(prefix))), 'w') as f:
            # `poly_idx` instead of `id`: avoid shadowing the builtin.
            for poly_idx in range(polys.shape[0]):
                f.write('{}, {}, {}, {}, {}, {}, {}, {}\n'.format(
                    int(polys[poly_idx, 0] / scale_x), int(polys[poly_idx, 1] / scale_y),
                    int(polys[poly_idx, 2] / scale_x), int(polys[poly_idx, 3] / scale_y),
                    int(polys[poly_idx, 4] / scale_x), int(polys[poly_idx, 5] / scale_y),
                    int(polys[poly_idx, 6] / scale_x), int(polys[poly_idx, 7] / scale_y)
                ))

        if image_index % 10 == 0:
            args.logger.info('Processing... {:.2f}% '.format(image_index * args.per_batch_size / data_size * 100))

    args.logger.info('Calculating mAP...')
    resDict = detection.evaluate_method(args.ann_file + '/', args.result + 'detect/')

    cost_time = time.time() - start_time
    args.logger.info('\n=============fots eval reulst=========\n')

    args.logger.info('precision:{}'.format(resDict['method']['precision']))
    print('precision:{}'.format(resDict['method']['precision']))
    args.logger.info('recall:{}'.format(resDict['method']['recall']))
    print('recall:{}'.format(resDict['method']['recall']))
    args.logger.info('hmean:{}'.format(resDict['method']['hmean']))
    print('hmean:{}'.format(resDict['method']['hmean']))
    args.logger.info('AP:{}'.format(resDict['method']['AP']))
    print('AP:{}'.format(resDict['method']['AP']))
    args.logger.info('testing cost time {:.2f}s'.format(cost_time))
    print('testing cost time {:.2f}s'.format(cost_time))
|