|
- import os
- import random
- import warnings
- import numpy as np
- import torch
- from models.mynet.mynet4 import MyNet
- from PIL import Image
- import transforms as T
- import torch.nn as nn
- from utils.c_train_and_eval import evaluate
- from utils.change_data import MyDataset
-
# Seed Python's built-in RNG once, at import time. Only the `random` module is
# seeded here; numpy and torch generators keep their default state.
_SEED = 47
random.seed(_SEED)
-
-
- class SegmentationPresetTrain:
- def __init__(self, hflip_prob=0.5, vflip_prob=0.5,
- # mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), size=256):
- mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), size=256):
- trans = []
- # trans.append(T.CenterCrop(0.5, size))
- # if hflip_prob > 0:
- # trans.append(T.RandomHorizontalFlip(hflip_prob))
- # if vflip_prob > 0:
- # trans.append(T.RandomVerticalFlip(vflip_prob))
- trans.append(T.RandomRotation(0.5))
- # trans.append(T.RandomEqualize(0.5))
- # trans.append(T.GaussianBlur(0.5))
- trans.extend([
- T.ToTensor(),
- T.Normalize(mean=mean, std=std),
- ])
- self.transforms = T.Compose(trans)
-
- def __call__(self, image1, image2, target):
- return self.transforms(image1, image2, target)
-
-
def get_transform(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Build a ``SegmentationPresetTrain`` with the given normalisation statistics.

    Only ``mean`` and ``std`` are forwarded; every other preset option keeps
    its default value.
    """
    preset = SegmentationPresetTrain(mean=mean, std=std)
    return preset
-
-
# Silence every warning process-wide.
# NOTE(review): this also hides potentially useful diagnostics (e.g. library
# deprecation notices); consider narrowing it to specific categories.
warnings.filterwarnings("ignore")

# Reference checkpoints and datasets.
# These are kept as comments rather than a bare string literal. In a non-raw
# string the Windows backslashes were parsed as escape sequences: "\t" became a
# tab and "\1" an octal escape, and invalid ones such as "\D" raise a
# SyntaxWarning at compile time, before the filter above can run.
# Path segment glossary: 模型权重文件 = "model weight files", 建筑物 = "buildings".
#
# SNUnet: D:\模型权重文件\建筑物\20230412-033255\SNUnet.pth
# FC_EF:  D:\模型权重文件\建筑物\20230412-072830\FC_EF.pth
#
# GZ:  D:\模型权重文件\建筑物\GZ\20230531-105508\best.pth   (88.4)
#      D:\Datasets\Data_CD\CD_Data_GZ\CD_Data_GZ\256\test
#      D:\Download\SiamUnet_diff_best (3).pth   (89.2)
#      D:\Download\SiamUnet_diff_best (4).pth   (val 87.22, best)
#
# WHU: D:\模型权重文件\建筑物\WHU\20230525-205745\best.pth
#      D:\Datasets\Data_CD\WHU\1\out\test
-
-
def parse_args(argv=None):
    """Parse the command-line options of the evaluation script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            ``argparse`` read ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``ckpt_url``, ``modelname``, ``data_path``,
        ``device`` and ``out_path``.
    """
    import argparse

    parser = argparse.ArgumentParser(description="pytorch model evaluation")
    parser.add_argument("--ckpt_url", default=r"D:\Download\104model_ema_best.pth",
                        help="path to the saved model checkpoint (.pth)")
    parser.add_argument("--modelname", default="",
                        help="model name (not read by main())")
    parser.add_argument("--data_path", default=r"C:\Data_CD\LEVIR\256_2\test",
                        help="root directory of the evaluation dataset")
    parser.add_argument("--device", default="cuda",
                        help="evaluation device (falls back to cpu when CUDA is unavailable)")
    parser.add_argument("--out_path", default=r"C:\LangChao\b_detection\test\result\mynet",
                        help="output directory for results (not read by main())")
    return parser.parse_args(argv)
-
-
def _load_full_model(path, device):
    """Unpickle a complete model object from *path*, mapping its tensors onto *device*.

    ``main`` uses the loaded object directly as an ``nn.Module`` (it calls
    ``.modules()``, ``.eval()`` and ``.to()`` on it), so the checkpoint must
    be a fully pickled model rather than a bare state_dict.

    SECURITY: unpickling can execute arbitrary code -- only load trusted files.
    """
    import inspect  # local: only needed for the keyword check below

    load_kwargs = {"map_location": device}
    # PyTorch >= 2.6 defaults to weights_only=True, which rejects pickled
    # module objects; opt out explicitly wherever the keyword exists (>= 1.13).
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    return torch.load(path, **load_kwargs)


def main(args):
    """Evaluate a pickled model checkpoint on a dataset and print the result.

    Args:
        args: Namespace produced by ``parse_args``; reads ``device``,
            ``ckpt_url`` and ``data_path``.

    The device falls back to CPU when CUDA is unavailable. The checkpoint is
    mapped onto that same device, so CPU-only machines can load it. The
    printed value is ``str()`` of what ``evaluate`` returns (presumably a
    confusion-matrix summary -- confirm in utils.c_train_and_eval).
    """
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    model = _load_full_model(args.ckpt_url, device)
    # NOTE(review): presumably a compatibility workaround for Upsample layers
    # pickled under a different PyTorch version -- confirm.
    for module in model.modules():
        if isinstance(module, nn.Upsample):
            module.recompute_scale_factor = None
    model.eval()
    model.to(device)

    val_dataset = MyDataset(args.data_path)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=1,
                                             num_workers=0,
                                             pin_memory=False)
    confmat = evaluate(model, val_loader,
                       device=device,
                       num_classes=2, print_freq=500)
    print(str(confmat))
-
-
if __name__ == "__main__":
    # Script entry point (original note: automated configuration for job
    # scripts on the AI-computing network cluster). Parse CLI options, then run
    # the evaluation. A `set_seed(args)` call used to sit here commented out;
    # no `set_seed` is defined in this file.
    args = parse_args()
    main(args)
|