|
- import os
- import sys
- import json
- import pickle
- import random
-
- import torch
- from tqdm import tqdm
-
- # import matplotlib.pyplot as plt
-
-
# Extensions (compared case-insensitively) that count as images.
_IMAGE_EXTENSIONS = frozenset((".jpg", ".jpeg", ".png"))


def _list_images(directory):
    """Return the sorted paths of supported image files directly inside *directory*."""
    if not os.path.isdir(directory):
        # A class without a folder in this split simply contributes no images.
        return []
    return [
        os.path.join(directory, name)
        for name in sorted(os.listdir(directory))
        if os.path.splitext(name)[-1].lower() in _IMAGE_EXTENSIONS
    ]


def _manifest_entry(img_path, class_name):
    """Build the ``phase1_set.json`` record for one image of class *class_name*."""
    # The id is the part of the file name before the first dot.
    image_id = os.path.basename(img_path).split(".")[0]
    return {"image_id": int(image_id), "category_id": int(class_name)}


def read_split_data(root: str, val_rate: float = 0.1):
    """Collect image paths and labels from ``root/train`` and ``root/val``.

    Every sub-directory of ``root/train`` is one class; classes are sorted by
    name and numbered from 0. For each class the images in ``train/<class>``
    and ``val/<class>`` are gathered (files are visited in sorted order so the
    result is reproducible across file systems).

    Side effects (both files are written to the current working directory):
        * ``class_indices.json`` - mapping of class index -> class name.
        * ``phase1_set.json`` - list of ``{"image_id", "category_id"}``
          records, per class: training images first, then validation images.
        * The global ``random`` generator is re-seeded with 1.

    Args:
        root: Dataset directory containing ``train`` and ``val`` folders.
        val_rate: Unused; the split is taken from the directory layout.
            Kept so existing callers keep working.

    Returns:
        ``(train_images_path, train_images_label,
        val_images_path, val_images_label)`` as four lists.

    Raises:
        AssertionError: If ``root`` does not exist.
        ValueError: If an image file stem or a class directory name is not an
            integer (both are converted with ``int`` for the manifest).
    """
    random.seed(1)  # keep random results reproducible
    assert os.path.exists(root), "dataset root: {} does not exist.".format(
        root)

    train_root = os.path.join(root, 'train')
    val_root = os.path.join(root, 'val')

    # One folder per class; sort so the name -> index mapping is stable.
    class_names = sorted(
        cla for cla in os.listdir(train_root)
        if os.path.isdir(os.path.join(train_root, cla)))
    class_indices = {name: index for index, name in enumerate(class_names)}

    json_str = json.dumps(
        {index: name for name, index in class_indices.items()}, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    train_images_path = []   # paths of all training images
    train_images_label = []  # class index of each training image
    val_images_path = []     # paths of all validation images
    val_images_label = []    # class index of each validation image
    manifest = []            # records dumped to phase1_set.json
    total_images = 0

    for cla in class_names:
        image_class = class_indices[cla]
        train_images = _list_images(os.path.join(train_root, cla))
        val_images = _list_images(os.path.join(val_root, cla))
        total_images += len(train_images) + len(val_images)

        for img_path in train_images:
            manifest.append(_manifest_entry(img_path, cla))
            train_images_path.append(img_path)
            train_images_label.append(image_class)

        for img_path in val_images:
            manifest.append(_manifest_entry(img_path, cla))
            val_images_path.append(img_path)
            val_images_label.append(image_class)

    print("{} images were found in the dataset.".format(total_images))
    print("{} images for training.".format(len(train_images_path)))
    print("{} images for validation.".format(len(val_images_path)))
    with open("./phase1_set.json", "w") as f:
        json.dump(manifest, f)

    return train_images_path, train_images_label, val_images_path, val_images_label
-
- '''
- def plot_data_loader_image(data_loader):
- batch_size = data_loader.batch_size
- plot_num = min(batch_size, 4)
-
- json_path = './class_indices.json'
- assert os.path.exists(json_path), json_path + " does not exist."
- json_file = open(json_path, 'r')
- class_indices = json.load(json_file)
-
- for data in data_loader:
- images, labels = data
- for i in range(plot_num):
- # [C, H, W] -> [H, W, C]
- img = images[i].numpy().transpose(1, 2, 0)
- # 反Normalize操作
- img = (img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255
- label = labels[i].item()
- plt.subplot(1, plot_num, i + 1)
- plt.xlabel(class_indices[str(label)])
- plt.xticks([]) # 去掉x轴的刻度
- plt.yticks([]) # 去掉y轴的刻度
- plt.imshow(img.astype('uint8'))
- plt.show()
- '''
-
def write_pickle(list_info: list, file_name: str):
    """Serialize *list_info* with pickle into the file *file_name*.

    The file is opened in binary write mode, so an existing file at that path
    is overwritten.
    """
    with open(file_name, 'wb') as sink:
        pickler = pickle.Pickler(sink)
        pickler.dump(list_info)
-
-
def read_pickle(file_name: str) -> list:
    """Load and return the object pickled in the file *file_name*.

    NOTE(review): unpickling can execute arbitrary code, so this should only
    be pointed at files the program wrote itself - confirm against callers.
    """
    with open(file_name, 'rb') as source:
        return pickle.Unpickler(source).load()
-
-
def train_one_epoch(model, optimizer, data_loader, device, epoch):
    """Train *model* for one pass over *data_loader* with cross-entropy loss.

    Args:
        model: Module called on each image batch; put into train mode here.
        optimizer: Optimizer stepped (and zeroed) after every batch.
        data_loader: Iterable yielding ``(images, labels)`` batches. If it
            exposes a ``desc`` attribute (e.g. a tqdm progress bar), the
            running loss/accuracy is written to it.
        device: Device the batches and accumulators are moved to.
        epoch: Epoch number, used only in the progress description.

    Returns:
        ``(mean_loss, accuracy)``: the mean of the per-batch losses and the
        fraction of correctly classified samples. ``(0.0, 0.0)`` when
        *data_loader* yields no batches.

    Raises:
        SystemExit: With status 1 when a batch produces a non-finite loss.
    """
    model.train()
    loss_function = torch.nn.CrossEntropyLoss()
    accu_loss = torch.zeros(1).to(device)  # accumulated loss
    accu_num = torch.zeros(1).to(device)   # accumulated number of correct predictions
    optimizer.zero_grad()

    sample_num = 0
    step_num = 0
    # Only a progress-bar wrapper has a description worth updating; a plain
    # loader or list is left untouched.
    show_progress = hasattr(data_loader, "desc")
    for images, labels in data_loader:
        images = images.to(device)
        labels = labels.to(device)
        step_num += 1
        sample_num += images.shape[0]

        pred = model(images)
        pred_classes = torch.max(pred, dim=1)[1]
        accu_num += torch.eq(pred_classes, labels).sum()

        loss = loss_function(pred, labels)
        # Bail out before backpropagating a NaN/inf loss.
        if not torch.isfinite(loss):
            print('WARNING: non-finite loss, ending training ', loss)
            sys.exit(1)

        loss.backward()
        accu_loss += loss.detach()

        if show_progress:
            data_loader.desc = "[train epoch {}] loss: {:.3f}, acc: {:.3f}".format(
                epoch,
                accu_loss.item() / step_num,
                accu_num.item() / sample_num)

        optimizer.step()
        optimizer.zero_grad()

    if step_num == 0:
        # Nothing was iterated: avoid dividing by zero.
        return 0.0, 0.0
    return accu_loss.item() / step_num, accu_num.item() / sample_num
-
-
def train_one_epoch2(model, optimizer, data_loader, device, epoch, *,
                     eps=0.02, iters=4, alpha=0.001):
    """Adversarially train *model* for one pass over *data_loader*.

    Each batch is first perturbed with ``PGD_attack`` and the model is then
    optimised with cross-entropy loss on the perturbed images.

    Args:
        model: Module called on each image batch; put into train mode here.
        optimizer: Optimizer stepped (and zeroed) after every batch.
        data_loader: Iterable yielding ``(images, labels)`` batches.
        device: Device the batches and accumulators are moved to.
        epoch: Epoch number; currently unused.
        eps: Maximum per-element perturbation passed to ``PGD_attack``.
        iters: Number of attack iterations passed to ``PGD_attack``.
        alpha: Attack step size passed to ``PGD_attack``.

    Returns:
        ``(mean_loss, accuracy)`` measured on the perturbed batches: the mean
        of the per-batch losses and the fraction of correctly classified
        samples. ``(0.0, 0.0)`` when *data_loader* yields no batches.

    Raises:
        SystemExit: With status 1 when a batch produces a non-finite loss.
    """
    model.train()
    loss_function = torch.nn.CrossEntropyLoss()
    accu_loss = torch.zeros(1).to(device)  # accumulated loss
    accu_num = torch.zeros(1).to(device)   # accumulated number of correct predictions
    optimizer.zero_grad()

    sample_num = 0
    step_num = 0
    for images, labels in data_loader:
        images = images.to(device)
        labels = labels.to(device)
        step_num += 1
        sample_num += images.shape[0]

        pgd_images = PGD_attack(model, images, labels, eps, iters, alpha,
                                device)
        # Crafting the attack may backpropagate through the model and leave
        # gradients on its parameters; clear them so the update below is
        # driven by the adversarial training loss only.
        optimizer.zero_grad()

        pred = model(pgd_images)
        pred_classes = torch.max(pred, dim=1)[1]
        accu_num += torch.eq(pred_classes, labels).sum()

        loss = loss_function(pred, labels)
        # Bail out before backpropagating a NaN/inf loss.
        if not torch.isfinite(loss):
            print('WARNING: non-finite loss, ending training ', loss)
            sys.exit(1)

        loss.backward()
        accu_loss += loss.detach()

        optimizer.step()
        optimizer.zero_grad()

    if step_num == 0:
        # Nothing was iterated: avoid dividing by zero.
        return 0.0, 0.0
    return accu_loss.item() / step_num, accu_num.item() / sample_num
-
-
def PGD_attack(model, input_data, labels, eps, iters, alpha, device):
    """Craft adversarial examples with iterated signed-gradient steps.

    Starting from *input_data*, each iteration moves the images by
    ``alpha * sign(gradient of the cross-entropy loss w.r.t. the images)``
    and clamps the total perturbation to ``[-eps, eps]`` per element.

    The gradient is taken with respect to a private copy of the images only,
    so neither *input_data* (its ``requires_grad`` flag) nor the gradients
    stored on the model parameters are modified.

    Args:
        model: Module mapping an image batch to class scores.
        input_data: Image batch to perturb.
        labels: Target class indices for the batch.
        eps: Maximum absolute perturbation per element.
        iters: Number of attack iterations.
        alpha: Step size of each iteration.
        device: Unused; tensors stay on the device of *input_data*. Kept for
            interface compatibility.

    Returns:
        A detached tensor with the perturbed images.
    """
    loss_function = torch.nn.CrossEntropyLoss()

    ori_images = input_data.detach()
    adv_images = ori_images.clone()

    for _ in range(iters):
        adv_images.requires_grad_(True)
        cost = loss_function(model(adv_images), labels)
        # Differentiate w.r.t. the images only; parameter .grad is untouched.
        grad, = torch.autograd.grad(cost, adv_images)

        stepped = adv_images.detach() + alpha * grad.sign()
        eta = torch.clamp(stepped - ori_images, min=-eps, max=eps)
        adv_images = (ori_images + eta).detach()

    return adv_images
-
-
@torch.no_grad()
def evaluate(model, data_loader, device, epoch):
    """Evaluate *model* on *data_loader* with cross-entropy loss, without gradients.

    Args:
        model: Module called on each image batch; put into eval mode here.
        data_loader: Iterable yielding ``(images, labels)`` batches.
        device: Device the batches and accumulators are moved to.
        epoch: Epoch number; currently unused.

    Returns:
        ``(mean_loss, accuracy)``: the mean of the per-batch losses and the
        fraction of correctly classified samples. ``(0.0, 0.0)`` when
        *data_loader* yields no batches.
    """
    loss_function = torch.nn.CrossEntropyLoss()

    model.eval()

    accu_num = torch.zeros(1).to(device)   # accumulated number of correct predictions
    accu_loss = torch.zeros(1).to(device)  # accumulated loss

    sample_num = 0
    step_num = 0
    for images, labels in data_loader:
        images = images.to(device)
        labels = labels.to(device)
        step_num += 1
        sample_num += images.shape[0]

        pred = model(images)
        pred_classes = torch.max(pred, dim=1)[1]
        accu_num += torch.eq(pred_classes, labels).sum()

        accu_loss += loss_function(pred, labels)

    if step_num == 0:
        # Nothing was iterated: avoid dividing by zero.
        return 0.0, 0.0
    return accu_loss.item() / step_num, accu_num.item() / sample_num
|