|
- import Levenshtein as Lev
- # import torch
- # from six.moves import xrange
-
-
class Decoder(object):
    """
    Basic decoder class from which all other decoders inherit. Implements several
    helper functions. Subclasses should implement the decode() method.
    Arguments:
        labels (list): mapping from integers to characters (a list or a string
            of single characters; index i is the label for class i).
        blank_index (int, optional): index for the blank '_' character. Defaults to 0.
    """

    def __init__(self, labels, blank_index=0):
        self.labels = labels
        # Integer class id -> label character.
        self.int_to_char = dict(enumerate(labels))
        self.blank_index = blank_index
        # When the label set has no ' ', fall back to an out-of-bounds index
        # (len(labels)) so space_index is always an int that never equals a
        # real label id; the original intent is to prevent errors in decode.
        self.space_index = labels.index(' ') if ' ' in labels else len(labels)

    def wer(self, s1, s2):
        """
        Computes the word-level edit distance between two sentences after
        tokenizing them on whitespace.

        Note: despite the name, this returns the raw number of word
        insertions/deletions/substitutions, not a rate; callers that want
        a Word Error Rate must divide by the reference word count themselves.
        Arguments:
            s1 (string): space-separated sentence
            s2 (string): space-separated sentence
        Returns:
            int: word-level Levenshtein distance between s1 and s2
        """
        words1 = s1.split()
        words2 = s2.split()

        # The Levenshtein package only compares strings, so encode every
        # distinct word as one unique character. Any injective word->char
        # mapping gives the same distance; dict.fromkeys keeps the assignment
        # deterministic (first-occurrence order) instead of set order.
        vocab = {word: chr(i)
                 for i, word in enumerate(dict.fromkeys(words1 + words2))}

        encoded1 = ''.join(vocab[w] for w in words1)
        encoded2 = ''.join(vocab[w] for w in words2)
        return Lev.distance(encoded1, encoded2)

    def cer(self, s1, s2):
        """
        Computes the character-level edit distance between two sentences.

        All ' ' characters are removed from both inputs before comparison,
        so spacing differences are not counted. Like wer(), this returns a
        raw count of character edits, not a rate.
        Arguments:
            s1 (string): space-separated sentence
            s2 (string): space-separated sentence
        Returns:
            int: character-level Levenshtein distance with spaces removed
        """
        s1, s2 = s1.replace(' ', ''), s2.replace(' ', '')
        return Lev.distance(s1, s2)

    def decode(self, probs, sizes=None):
        """
        Given a matrix of character probabilities, returns the decoder's
        best guess of the transcription. Must be implemented by subclasses.
        Arguments:
            probs: Tensor of character probabilities, where probs[c,t]
                is the probability of character c at time t (exact layout
                is defined by the concrete subclass)
            sizes(optional): Size of each sequence in the mini-batch
        Returns:
            string: sequence of the model's best guess for the transcription
        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError
-
-
- # class GreedyDecoder(Decoder):
-
- # def __init__(self, labels, blank_index=0):
- # super(GreedyDecoder, self).__init__(labels, blank_index)
-
- # def convert_to_strings(self,
- # sequences,
- # sizes=None,
- # remove_repetitions=False,
- # return_offsets=False):
- # """Given a list of numeric sequences, returns the corresponding strings"""
- # strings = []
- # offsets = [] if return_offsets else None
- # for x in xrange(len(sequences)):
- # seq_len = sizes[x] if sizes is not None else len(sequences[x])
- # string, string_offsets = self.process_string(sequences[x], seq_len, remove_repetitions)
- # strings.append([string]) # We only return one path
- # if return_offsets:
- # offsets.append([string_offsets])
- # if return_offsets:
- # return strings, offsets
- # else:
- # return strings
-
- # def process_string(self, sequence, size, remove_repetitions=False):
- # string = ''
- # offsets = []
- # for i in range(size):
- # char = self.int_to_char[sequence[i].item()]
- # if char != self.int_to_char[self.blank_index]:
- # # if this char is a repetition and remove_repetitions=true, then skip
- # if remove_repetitions and i != 0 and char == self.int_to_char[sequence[i -
- # 1].item()]:
- # pass
- # elif char == self.labels[self.space_index]:
- # string += ' '
- # offsets.append(i)
- # else:
- # string = string + char
- # offsets.append(i)
- # return string, torch.tensor(offsets, dtype=torch.int)
-
- # def decode(self, probs, sizes=None):
- # """
- # Returns the argmax decoding given the probability matrix. Removes
- # repeated elements in the sequence, as well as blanks.
- # Arguments:
- # probs: Tensor of character probabilities from the network. Expected shape of batch x seq_length x output_dim
- # sizes(optional): Size of each sequence in the mini-batch
- # Returns:
- # strings: sequences of the model's best guess for the transcription on inputs
- # offsets: time step per character predicted
- # """
- # _, max_probs = torch.max(probs, 2)
- # strings, offsets = self.convert_to_strings(max_probs.view(max_probs.size(0),
- # max_probs.size(1)),
- # sizes,
- # remove_repetitions=True,
- # return_offsets=True)
- # return strings, offsets
|