|
import logging
import math
import os
import shutil
import sys
import time
from argparse import ArgumentParser

import torch
import torch.nn as nn
import torch.nn.init as init
-
-
- def parse_args():
- parser = ArgumentParser(description='PyTorch/torchtext NLI Baseline')
- parser.add_argument('--dataset', '-d', type=str, default='multinli')
- parser.add_argument('--model', '-m', type=str, default='Causformer_main')
- parser.add_argument('--gpu', type=int, default=0)
- parser.add_argument('--batch_size', type=int, default=64)
- parser.add_argument('--embed_dim', type=int, default=768)
- parser.add_argument('--d_hidden', type=int, default=200)
- parser.add_argument('--dp_ratio', type=int, default=0.2)
- parser.add_argument('--epochs', type=int, default=1000)
- parser.add_argument('--lr', type=float, default=0.00002)
- parser.add_argument('--combine', type=str, default='cat')
- parser.add_argument('--parallel', type=bool, default=True)
- parser.add_argument('--results_dir', type=str, default='results')
- return check_args(parser.parse_args())
-
-
- """checking arguments"""
-
-
def check_args(args):
    """Validate parsed CLI arguments and ensure the results dir exists.

    Matches the original's lenient behaviour: invalid values are only
    reported on stdout, never raised, and `args` is returned unchanged.
    """
    # --results_dir: create results/<model>/<dataset> if missing.
    check_folder(os.path.join(args.results_dir, args.model, args.dataset))

    # --epochs: the original wrapped an `assert` in a bare `except:`;
    # asserts vanish under `python -O`, so test the condition explicitly.
    if args.epochs < 1:
        print('number of epochs must be larger than or equal to one')

    # --batch_size: same explicit check instead of assert + bare except.
    if args.batch_size < 1:
        print('batch size must be larger than or equal to one')
    return args
-
-
def get_device():
    """Select the torch device: CUDA when available, CPU otherwise.

    Prints the chosen device and the number of visible GPUs, then
    returns the torch.device.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda") if use_cuda else torch.device("cpu")
    n_gpu = torch.cuda.device_count()
    print("%s (%d GPUs)" % (device, n_gpu))
    return device
-
-
def makedirs(name):
    """Create directory `name` (and parents), ignoring "already exists".

    Equivalent to the original's hand-rolled EEXIST/isdir handling:
    an existing *directory* is fine, but any other OSError — including
    `name` existing as a regular file — still propagates.
    """
    # exist_ok=True replaces the old try/except errno.EEXIST dance and
    # removes the redundant function-local `import os, errno`.
    os.makedirs(name, exist_ok=True)
-
-
def check_folder(log_dir):
    """Ensure `log_dir` exists, creating it if needed, and return it.

    Uses exist_ok=True to avoid the check-then-create race in the
    original exists()/makedirs() pair (two processes could both pass
    the exists() test and one makedirs() would then raise).
    """
    os.makedirs(log_dir, exist_ok=True)
    return log_dir
-
-
def get_logger(args, phase):
    """Return a logger named `phase` logging to the run's results folder.

    Configures the root logger to write INFO-level records to
    <results_dir>/<model>/<dataset>/<phase>.log.

    NOTE: logging.basicConfig is a no-op once the root logger has
    handlers, so only the first call in a process takes effect.
    """
    log_path = "{}/{}/{}/{}.log".format(
        args.results_dir, args.model, args.dataset, phase)
    logging.basicConfig(
        level=logging.INFO,
        filename=log_path,
        format='%(asctime)s - %(message)s',
        datefmt='%d-%b-%y %H:%M:%S',
    )
    return logging.getLogger(phase)
-
-
def get_mean_and_std(dataset):
    """Compute per-channel (3-channel) mean and std over a dataset.

    Iterates one sample at a time and averages the per-sample channel
    statistics over len(dataset) — i.e. the returned std is the mean of
    per-image stds, not the std over all pixels pooled together.
    """
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=1, shuffle=True, num_workers=2)
    mean, std = torch.zeros(3), torch.zeros(3)
    print('==> Computing mean and std..')
    for inputs, _targets in loader:
        for channel in range(3):
            mean[channel] += inputs[:, channel, :, :].mean()
            std[channel] += inputs[:, channel, :, :].std()
    n_samples = len(dataset)
    mean.div_(n_samples)
    std.div_(n_samples)
    return mean, std
-
-
def init_params(net):
    """Initialise layer parameters of `net` in place.

    Conv2d:      Kaiming-normal weights (mode='fan_out'), zero bias.
    BatchNorm2d: unit weight, zero bias.
    Linear:      normal(std=1e-3) weights, zero bias.
    """
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            # BUG FIX: kaiming_normal/constant/normal were deprecated and
            # later removed from torch.nn.init; use the in-place `_`
            # variants.
            init.kaiming_normal_(m.weight, mode='fan_out')
            # BUG FIX: `if m.bias:` raises "Boolean value of Tensor with
            # more than one element is ambiguous" for any real bias;
            # the correct presence test is `is not None`.
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            init.constant_(m.weight, 1)
            init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant_(m.bias, 0)
-
-
# Terminal width used by progress_bar. shutil.get_terminal_size() honours
# the COLUMNS env var, works on every platform, and falls back cleanly when
# stdout is not a tty — unlike the original `os.popen('stty size')` pipe,
# which was POSIX-only and guarded by a bare `except:`.
term_width = int(shutil.get_terminal_size(fallback=(80, 24)).columns)

TOTAL_BAR_LENGTH = 65.
last_time = time.time()   # timestamp of the previous progress_bar call
begin_time = last_time    # timestamp when the current bar was started
-
-
def progress_bar(current, total, msg=None):
    """Render a one-line textual progress bar on stdout.

    Redraws ' [====>....]  Step: <t> | Tot: <t> | <msg>   i/total' in
    place using '\r' after every step except the last, which ends with a
    newline. Relies on the module globals `term_width`,
    `TOTAL_BAR_LENGTH`, `last_time`, and `begin_time`.

    Args:
        current: zero-based index of the step just completed.
        total:   total number of steps.
        msg:     optional extra text appended after the timing info.
    """
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    # Split the fixed-width bar into a filled part and a remainder;
    # one column is reserved for the '>' head.
    cur_len = int(TOTAL_BAR_LENGTH * current / total)
    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1

    sys.stdout.write(' [')
    for i in range(cur_len):
        sys.stdout.write('=')
    sys.stdout.write('>')
    for i in range(rest_len):
        sys.stdout.write('.')
    sys.stdout.write(']')

    # Per-step and cumulative timing since the bar was (re)started.
    cur_time = time.time()
    step_time = cur_time - last_time
    last_time = cur_time
    tot_time = cur_time - begin_time

    L = []
    L.append('  Step: %s' % format_time(step_time))
    L.append(' | Tot: %s' % format_time(tot_time))
    if msg:
        L.append(' | ' + msg)

    msg = ''.join(L)
    sys.stdout.write(msg)
    # Pad with spaces to overwrite any residue from a longer previous line.
    for i in range(term_width - int(TOTAL_BAR_LENGTH) - len(msg) - 3):
        sys.stdout.write(' ')

    # Go back to the center of the bar.
    for i in range(term_width - int(TOTAL_BAR_LENGTH / 2) + 2):
        sys.stdout.write('\b')
    sys.stdout.write(' %d/%d ' % (current + 1, total))

    # Carriage return keeps redrawing on the same line until the final
    # step, which terminates the bar with a newline.
    if current < total - 1:
        sys.stdout.write('\r')
    else:
        sys.stdout.write('\n')
    sys.stdout.flush()
-
-
def format_time(seconds):
    """Format a duration in seconds as a compact string.

    At most the two largest non-zero units are shown, chosen from
    days/hours/minutes/seconds/milliseconds; a zero duration yields
    '0ms'. Examples: 1.5 -> '1s500ms', 3661 -> '1h1m', 90000 -> '1D1h'.
    """
    days = int(seconds / 3600 / 24)
    seconds = seconds - days * 3600 * 24
    hours = int(seconds / 3600)
    seconds = seconds - hours * 3600
    minutes = int(seconds / 60)
    seconds = seconds - minutes * 60
    whole_secs = int(seconds)
    millis = int((seconds - whole_secs) * 1000)

    # Collect at most two non-zero components, largest unit first.
    pieces = []
    for value, unit in ((days, 'D'), (hours, 'h'), (minutes, 'm'),
                        (whole_secs, 's'), (millis, 'ms')):
        if value > 0 and len(pieces) < 2:
            pieces.append('%d%s' % (value, unit))
    return ''.join(pieces) or '0ms'
|