|
- # Copyright 2018 Dong-Hyun Lee, Kakao Brain.
-
- """ Utils Functions """
-
- import os
- import random
- import logging
- import numpy as np
- import torch
- import time
- import sys
-
-
def format_time(seconds):
    """Render a duration given in seconds as a compact string like '1h2m'.

    The duration is broken down into days, hours, minutes, whole seconds and
    milliseconds. Only the first two units (most significant first) whose
    value is positive are emitted; if none is positive, '0ms' is returned.
    """
    remaining = seconds
    days = int(remaining / 3600 / 24)
    remaining = remaining - days * 3600 * 24
    hours = int(remaining / 3600)
    remaining = remaining - hours * 3600
    minutes = int(remaining / 60)
    remaining = remaining - minutes * 60
    whole_secs = int(remaining)
    remaining = remaining - whole_secs
    millis = int(remaining * 1000)

    units = (
        (days, 'D'),
        (hours, 'h'),
        (minutes, 'm'),
        (whole_secs, 's'),
        (millis, 'ms'),
    )
    # Keep at most the two most significant units that are present.
    shown = [str(value) + suffix for value, suffix in units if value > 0][:2]
    return ''.join(shown) or '0ms'
-
-
# Width (in characters) of the bar portion drawn by progress_bar.
TOTAL_BAR_LENGTH = 120.
# Timestamps used by progress_bar for per-step and total elapsed time.
last_time = time.time()
begin_time = last_time


def _detect_term_width(default=80):
    """Return the terminal column count reported by `stty size`.

    Falls back to ``default`` when the command cannot run or its output is
    not two integers (e.g. stdin is not a TTY, or stty is unavailable).
    """
    try:
        # Close the pipe deterministically instead of leaking it.
        with os.popen('stty size', 'r') as proc:
            _, width = proc.read().split()
        return int(width)
    except (OSError, ValueError):
        # ValueError covers both an empty/garbled reply (unpacking) and a
        # non-numeric width (int()).
        return default


term_width = _detect_term_width()
-
-
def progress_bar(current, total, msg=None):
    """Draw a single-line text progress bar on stdout.

    ``current`` is treated as a zero-based step index: the counter prints
    ``current + 1`` and the line is terminated with a newline once
    ``current == total - 1`` (otherwise with a carriage return so the next
    call overwrites it). Passing ``current == 0`` restarts the total timer.
    Step and total elapsed times are appended, followed by ``msg`` if given.
    """
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # a new bar starts: reset the total timer

    filled = int(TOTAL_BAR_LENGTH * current / total)
    remaining = int(TOTAL_BAR_LENGTH - filled) - 1
    sys.stdout.write(' [' + '=' * filled + '>' + '.' * remaining + ']')

    now = time.time()
    step_time = now - last_time
    last_time = now
    tot_time = now - begin_time

    pieces = [' Step: %s' % format_time(step_time),
              ' | Tot: %s' % format_time(tot_time)]
    if msg:
        pieces.append(' | ' + msg)
    info = ''.join(pieces)

    sys.stdout.write(info)
    # Pad with spaces to the terminal edge (no-op when the count is negative).
    sys.stdout.write(' ' * (term_width - int(TOTAL_BAR_LENGTH) - len(info) - 3))

    # Backspace to roughly the middle of the bar and print the step counter.
    sys.stdout.write('\b' * (term_width - int(TOTAL_BAR_LENGTH / 2) + 2))
    sys.stdout.write(' %d/%d ' % (current + 1, total))

    sys.stdout.write('\r' if current < total - 1 else '\n')
    sys.stdout.flush()
-
-
def set_seeds(seed):
    "seed the Python, NumPy, torch CPU and all torch CUDA generators"
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
-
-
def get_device():
    "return the CUDA device if available, else CPU; prints it with the GPU count"
    has_cuda = torch.cuda.is_available()
    device = torch.device("cuda") if has_cuda else torch.device("cpu")
    gpu_count = torch.cuda.device_count()
    print("%s (%d GPUs)" % (device, gpu_count))
    return device
-
-
def split_last(x, shape):
    "reshape the last dimension of x into `shape`; at most one entry may be -1 (inferred)"
    target = list(shape)
    assert target.count(-1) <= 1
    if -1 in target:
        # The product still includes the -1, so negating it yields the
        # product of the explicitly given sizes.
        known = -np.prod(target)
        target[target.index(-1)] = int(x.size(-1) / known)
    leading = x.size()[:-1]
    return x.view(*leading, *target)
-
-
def merge_last(x, n_dims):
    "flatten the trailing n_dims dimensions of x into a single dimension"
    dims = x.size()
    assert 1 < n_dims < len(dims)
    leading = dims[:-n_dims]
    return x.view(*leading, -1)
-
-
def find_sublist(haystack, needle):
    """Return the first index at which sequence `needle` occurs in sequence
    `haystack`, or -1 if it does not occur. An empty needle matches at 0.

    Uses the Boyer-Moore-Horspool algorithm: the window is compared from its
    right end, and on a mismatch it is shifted according to the element under
    the window's last position. Elements of both sequences must be hashable.
    https://codereview.stackexchange.com/questions/19627/finding-sub-list
    """
    hay_len = len(haystack)
    needle_len = len(needle)
    last = needle_len - 1

    # Bad-character table: distance from each element's last occurrence
    # (excluding the final position) to the end of the needle.
    shift = {}
    for idx in range(last):
        shift[needle[idx]] = last - idx

    pos = last  # haystack index aligned with the needle's last element
    while pos < hay_len:
        if all(haystack[pos - k] == needle[last - k] for k in range(needle_len)):
            return pos - last
        pos += shift.get(haystack[pos], needle_len)
    return -1
-
-
def truncate_tokens_pair(tokens_a, tokens_b, max_len):
    """Trim two token lists in place until their combined length <= max_len.

    Each step pops the last token of the longer list; on a tie, tokens_b is
    trimmed. Returns None (the lists are mutated).
    """
    while len(tokens_a) + len(tokens_b) > max_len:
        longer = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
        longer.pop()
-
-
def get_random_word(vocab_words):
    "return a uniformly random element of vocab_words (ValueError if empty)"
    return vocab_words[random.randrange(len(vocab_words))]
-
-
def get_logger(name, log_path):
    """Return the logger `name`, set to DEBUG and writing to `log_path`.

    The log file is created if missing and appended to otherwise. Calling
    this again with the same name and path reuses the already-attached file
    handler rather than adding a duplicate (which would write every record
    more than once). No console handler is attached.

    Args:
        name: logger name passed to ``logging.getLogger``.
        log_path: path of the log file.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(name)
    formatter = logging.Formatter(
        '[ %(levelname)s|%(filename)s:%(lineno)s] %(asctime)s > %(message)s')

    # FileHandler stores the absolute path in `baseFilename`; compare on that.
    target = os.path.abspath(log_path)
    already_attached = any(
        isinstance(handler, logging.FileHandler)
        and handler.baseFilename == target
        for handler in logger.handlers)

    if not already_attached:
        # FileHandler opens in append mode, which creates the file when it
        # does not exist, so no separate open()/touch (and no leaked file
        # handle) is needed.
        file_handler = logging.FileHandler(log_path)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    logger.setLevel(logging.DEBUG)
    return logger
-
-
def my_cross_entropy(input, target):
    """Unreduced cross-entropy loss computed from raw logits.

    For each row i this is ``-log(softmax(input[i])[target[i]])``, the same
    quantity as ``torch.nn.functional.cross_entropy(..., reduction='none')``.

    Args:
        input: logits; assumed shape (batch, num_classes), with dim 1 as the
            class axis.
        target: integer class indices; assumed shape (batch,).

    Returns:
        Tensor of shape (batch,) holding the loss of each sample.
    """
    # log_softmax applies the log-sum-exp trick (subtracts the row max), so
    # large logits do not overflow exp() into inf/inf = nan as a naive
    # exp(x) / sum(exp(x)) does.
    log_probs = torch.log_softmax(input, dim=1)
    # squeeze(-1) (not squeeze()) keeps a (1,) result for a batch of one.
    picked = log_probs.gather(1, target.unsqueeze(-1)).squeeze(-1)
    return -picked
|