|
- Classifying Names with a Character-level Spiking LSTM
- ==============================================================================
- Authors: `LiutaoYu <https://github.com/LiutaoYu>`_, `fangwei123456 <https://github.com/fangwei123456>`_
-
- This tutorial applies a Spiking LSTM to reproduce the PyTorch official tutorial `NLP From Scratch: Classifying Names with a Character-Level RNN <https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html>`_.
- Please make sure that you have read the original tutorial and the corresponding code before proceeding.
- Specifically, we will train a spiking LSTM to classify surnames by language of origin according to their spelling, using a dataset of several thousand surnames from 18 languages.
- The complete script can be found at `activation_based/examples/spiking_lstm_text.py <https://github.com/fangwei123456/spikingjelly/blob/master/spikingjelly/activation_based/examples/spiking_lstm_text.py>`_.
-
- Preparing the data
- ----------------------------
- First of all, we download and preprocess the data as in the original tutorial, which produces a dictionary ``{language: [names ...]}``.
- Then, we split the dataset into a training set and a testing set with a ratio of 4:1, i.e., ``category_lines_train`` and ``category_lines_test``.
- Several variables are used throughout: ``all_categories`` is the list of the 18 languages, so its length is ``n_categories=18``;
- ``n_letters=58`` is the size of the character set used to spell the surnames.
-
- .. code-block:: python
-
-     # split the data into training set and testing set
-     numExamplesPerCategory = []
-     category_lines_train = {}
-     category_lines_test = {}
-     testNumtot = 0
-     for c, names in category_lines.items():
-         category_lines_train[c] = names[:int(len(names)*0.8)]
-         category_lines_test[c] = names[int(len(names)*0.8):]
-         numExamplesPerCategory.append([len(category_lines[c]), len(category_lines_train[c]), len(category_lines_test[c])])
-         testNumtot += len(category_lines_test[c])
-
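The split above takes the first 80% of each language's names, in the order they were loaded, for training and the remainder for testing; no shuffling is applied. As a minimal sketch of that behavior, the toy dictionary ``toy_category_lines`` below is a made-up stand-in for ``category_lines`` (its keys, names and sizes are assumptions for illustration only):

```python
# Toy stand-in for category_lines (hypothetical data, not the real dataset).
toy_category_lines = {
    'LangA': ['name%d' % i for i in range(10)],
    'LangB': ['word%d' % i for i in range(5)],
}

toy_train, toy_test = {}, {}
for c, names in toy_category_lines.items():
    cut = int(len(names) * 0.8)  # index separating the first 80% from the rest
    toy_train[c] = names[:cut]
    toy_test[c] = names[cut:]

# Print the per-language sizes of the two subsets.
for c in toy_category_lines:
    print(c, len(toy_train[c]), len(toy_test[c]))
```

Because the split is done per language, every language keeps roughly the same 4:1 ratio in both sets.
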
- In addition, we rewrite the function ``randomTrainingExample()`` as ``randomPair(sampleSource)``, so that a sample can be drawn from the training set, the testing set, or the whole dataset.
- Here we adopt the functions ``lineToTensor()`` and ``randomChoice()`` from the original tutorial.
- ``lineToTensor()`` converts a surname into a one-hot tensor, and ``randomChoice()`` randomly chooses a sample from the dataset.
-
- .. code-block:: python
-
-     # Preparing [x, y] pair
-     def randomPair(sampleSource):
-         """
-         Args:
-             sampleSource: 'train', 'test', 'all'
-         Returns:
-             category, line, category_tensor, line_tensor
-         """
-         category = randomChoice(all_categories)
-         if sampleSource == 'train':
-             line = randomChoice(category_lines_train[category])
-         elif sampleSource == 'test':
-             line = randomChoice(category_lines_test[category])
-         elif sampleSource == 'all':
-             line = randomChoice(category_lines[category])
-         else:
-             # fail early instead of hitting an unbound ``line`` below
-             raise ValueError("sampleSource must be 'train', 'test' or 'all'")
-         category_tensor = torch.tensor([all_categories.index(category)], dtype=torch.float)
-         line_tensor = lineToTensor(line)
-         return category, line, category_tensor, line_tensor
-
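``lineToTensor()`` itself is defined in the original tutorial and is not repeated here. As a dependency-free sketch of the one-hot layout it produces (one row per character, with a single 1 at that character's index), the helper ``line_to_one_hot`` and the alphabet ``toy_letters`` below are hypothetical names introduced for illustration; the actual script builds a PyTorch tensor of shape ``[len(line), 1, n_letters]`` with its own character set.

```python
import string

# Hypothetical alphabet for illustration; the real script defines its own.
toy_letters = string.ascii_letters + " .,;'"

def line_to_one_hot(line, letters=toy_letters):
    """Return a nested list shaped [len(line)][1][len(letters)]."""
    rows = []
    for ch in line:
        one_hot = [0.0] * len(letters)
        one_hot[letters.index(ch)] = 1.0  # raises ValueError for unknown characters
        rows.append([one_hot])  # the extra level mimics a batch dimension of size 1
    return rows

encoded = line_to_one_hot('Yu')
# Sequence length, batch size, and alphabet size of the encoded name.
print(len(encoded), len(encoded[0]), len(encoded[0][0]))
```

Each character becomes one time step of the input sequence, which is why the network later reads only the output of the last step.
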
- Building a spiking LSTM network
- --------------------------------------
- We build a spiking LSTM based on the ``rnn`` module from `spikingjelly <https://github.com/fangwei123456/spikingjelly>`_ .
- The theory can be found in the paper `Long Short-Term Memory Spiking Networks and Their Applications <https://arxiv.org/abs/2007.04779>`_ .
- The numbers of neurons in the input layer, hidden layer and output layer are ``n_letters``, ``n_hidden`` and ``n_categories`` respectively.
- We add a fully connected layer after the spiking LSTM as the output layer, and use the ``softmax`` function to obtain the classification probabilities.
-
- .. code-block:: python
-
-     import torch.nn as nn
-     import torch.nn.functional as F
-     from spikingjelly.activation_based import rnn
-     n_hidden = 256
-
-     class Net(nn.Module):
-         def __init__(self, n_letters, n_hidden, n_categories):
-             super().__init__()
-             self.n_input = n_letters
-             self.n_hidden = n_hidden
-             self.n_out = n_categories
-             # single-layer spiking LSTM
-             self.lstm = rnn.SpikingLSTM(self.n_input, self.n_hidden, 1)
-             self.fc = nn.Linear(self.n_hidden, self.n_out)
-
-         def forward(self, x):
-             # x: one-hot sequence produced by lineToTensor()
-             x, _ = self.lstm(x)
-             # classify from the output of the last time step
-             output = self.fc(x[-1])
-             output = F.softmax(output, dim=1)
-             return output
-
- Training the network
- ---------------------------------------
- First of all, we initialize ``net`` and define hyperparameters such as ``TRAIN_EPISODES`` and ``learning_rate``.
- Here we adopt ``mse_loss`` and the ``Adam`` optimizer to train the network.
- Note that each "epoch" in this script is a single training iteration on one randomly drawn sample, which proceeds as follows:
-
- 1) randomly choose a sample from the training set, and convert the input and label into tensors;
- 2) feed the input to the network, and obtain the classification probabilities from the forward pass;
- 3) calculate the loss through ``mse_loss``;
- 4) back-propagate the gradients, and update the trainable parameters;
- 5) check whether the prediction is correct, and count the correct predictions to obtain the training accuracy every ``plot_every`` epochs;
- 6) evaluate the network on the testing set every ``plot_every`` epochs to obtain the testing accuracy.
-
- During training, we record the history of the loss ``avg_losses``, training accuracy ``accuracy_rec`` and testing accuracy ``test_accu_rec`` to observe the training process.
- After training, we save the trained network for testing, along with some variables for later analysis.
-
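Step 3 compares the softmax output with the one-hot label using the mean squared error. The following pure-Python sketch works through that computation; the helpers ``softmax`` and ``mse`` and the example logits are assumptions made for illustration, standing in for ``F.softmax`` and ``F.mse_loss`` with its default mean reduction.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def mse(pred, target):
    """Mean of the squared element-wise differences."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)

# Made-up 3-class example: the network favors class 0.
probs = softmax([2.0, 1.0, 0.1])
loss_correct = mse(probs, [1.0, 0.0, 0.0])  # label agrees with the prediction
loss_wrong = mse(probs, [0.0, 0.0, 1.0])    # label disagrees with the prediction
print(loss_correct, loss_wrong)

# An uninformed, uniform prediction over 18 classes.
uniform = [1.0 / 18] * 18
one_hot = [1.0] + [0.0] * 17
loss_uniform = mse(uniform, one_hot)  # equals 17/324, about 0.0525
print(loss_uniform)
```

The uniform case gives a loss of about 0.0525, which matches the average loss reported at the very start of training, when the network has not yet learned anything.
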
- .. code-block:: python
-
-     # IF_TRAIN = 1
-     TRAIN_EPISODES = 1000000
-     plot_every = 1000
-     learning_rate = 1e-4
-
-     net = Net(n_letters, n_hidden, n_categories)
-     optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
-
-     print('Training...')
-     current_loss = 0
-     correct_num = 0
-     avg_losses = []
-     accuracy_rec = []
-     test_accu_rec = []
-     start = time.time()
-     for epoch in range(1, TRAIN_EPISODES+1):
-         net.train()
-         category, line, category_tensor, line_tensor = randomPair('train')
-         label_one_hot = F.one_hot(category_tensor.to(int), n_categories).float()
-
-         optimizer.zero_grad()
-         out_prob_log = net(line_tensor)
-         loss = F.mse_loss(out_prob_log, label_one_hot)
-         loss.backward()
-         optimizer.step()
-
-         current_loss += loss.data.item()
-
-         guess, _ = categoryFromOutput(out_prob_log.data)
-         if guess == category:
-             correct_num += 1
-
-         # Add current loss avg to list of losses
-         if epoch % plot_every == 0:
-             avg_losses.append(current_loss / plot_every)
-             accuracy_rec.append(correct_num / plot_every)
-             current_loss = 0
-             correct_num = 0
-
-         # evaluate the network on the testing set every ``plot_every`` epochs to obtain the testing accuracy
-         if epoch % plot_every == 0:  # int(TRAIN_EPISODES/1000)
-             net.eval()
-             with torch.no_grad():
-                 numCorrect = 0
-                 for i in range(n_categories):
-                     category = all_categories[i]
-                     for tname in category_lines_test[category]:
-                         output = net(lineToTensor(tname))
-                         guess, _ = categoryFromOutput(output.data)
-                         if guess == category:
-                             numCorrect += 1
-                 test_accu = numCorrect / testNumtot
-                 test_accu_rec.append(test_accu)
-                 print('Epoch %d %d%% (%s); Avg_loss %.4f; Train accuracy %.4f; Test accuracy %.4f' % (
-                     epoch, epoch / TRAIN_EPISODES * 100, timeSince(start), avg_losses[-1], accuracy_rec[-1], test_accu))
-
-     torch.save(net, 'char_rnn_classification.pth')
-     np.save('avg_losses.npy', np.array(avg_losses))
-     np.save('accuracy_rec.npy', np.array(accuracy_rec))
-     np.save('test_accu_rec.npy', np.array(test_accu_rec))
-     np.save('category_lines_train.npy', category_lines_train, allow_pickle=True)
-     np.save('category_lines_test.npy', category_lines_test, allow_pickle=True)
-     # x = np.load('category_lines_test.npy', allow_pickle=True)  # how to load the data back
-     # xdict = x.item()
-
-     plt.figure()
-     plt.subplot(311)
-     plt.plot(avg_losses)
-     plt.title('Average loss')
-     plt.subplot(312)
-     plt.plot(accuracy_rec)
-     plt.title('Train accuracy')
-     plt.subplot(313)
-     plt.plot(test_accu_rec)
-     plt.title('Test accuracy')
-     plt.xlabel('Epoch (*1000)')
-     plt.subplots_adjust(hspace=0.6)
-     plt.savefig('TrainingProcess.svg')
-     plt.close()
-
- Executing ``%run ./spiking_lstm_text.py`` in an IPython console with ``IF_TRAIN = 1`` produces output like the following.
-
- .. code-block:: shell
-
-     Backend Qt5Agg is interactive backend. Turning interactive mode on.
-     Training...
-     Epoch 1000 0% (0m 18s); Avg_loss 0.0525; Train accuracy 0.0830; Test accuracy 0.0806
-     Epoch 2000 0% (0m 37s); Avg_loss 0.0514; Train accuracy 0.1470; Test accuracy 0.1930
-     Epoch 3000 0% (0m 55s); Avg_loss 0.0503; Train accuracy 0.1650; Test accuracy 0.0537
-     Epoch 4000 0% (1m 14s); Avg_loss 0.0494; Train accuracy 0.1920; Test accuracy 0.0938
-     ...
-     ...
-     Epoch 998000 99% (318m 54s); Avg_loss 0.0063; Train accuracy 0.9300; Test accuracy 0.5036
-     Epoch 999000 99% (319m 14s); Avg_loss 0.0056; Train accuracy 0.9380; Test accuracy 0.5004
-     Epoch 1000000 100% (319m 33s); Avg_loss 0.0055; Train accuracy 0.9340; Test accuracy 0.5118
-
- The following picture shows how the average loss ``avg_losses``, training accuracy ``accuracy_rec`` and testing accuracy ``test_accu_rec`` evolve during training.
-
- .. image:: ../_static/tutorials/activation_based/\9_spikingLSTM_text/TrainingProcess.*
-     :width: 100%
-
- Testing the network
- ---------------------------
- We first load the trained network, and then conduct the following tests:
-
- 1) calculate the testing accuracy of the final network;
- 2) predict the language of origin of surnames provided by the user;
- 3) calculate the confusion matrix, indicating for every actual language (rows) which language the network guesses (columns).
-
- .. code-block:: python
-
-     # IF_TRAIN = 0
-     print('Testing...')
-
-     net = torch.load('char_rnn_classification.pth')
-
-     # calculate the testing accuracy of the final network
-     print('Calculating testing accuracy...')
-     numCorrect = 0
-     for i in range(n_categories):
-         category = all_categories[i]
-         for tname in category_lines_test[category]:
-             output = net(lineToTensor(tname))
-             guess, _ = categoryFromOutput(output.data)
-             if guess == category:
-                 numCorrect += 1
-     test_accu = numCorrect / testNumtot
-     print('Test accuracy: {:.3f}, Random guess: {:.3f}'.format(test_accu, 1/n_categories))
-
-     # predict the language origin of the surnames provided by the user
-     n_predictions = 3
-     for j in range(3):
-         surname = input('Please input a surname to predict its language origin:')
-         print('\n> %s' % surname)
-         output = net(lineToTensor(surname))
-
-         # Get top N categories
-         topv, topi = output.topk(n_predictions, 1, True)
-         predictions = []
-
-         for i in range(n_predictions):
-             value = topv[0][i].item()
-             category_index = topi[0][i].item()
-             print('(%.2f) %s' % (value, all_categories[category_index]))
-             predictions.append([value, all_categories[category_index]])
-
-     # calculate the confusion matrix
-     print('Calculating confusion matrix...')
-     confusion = torch.zeros(n_categories, n_categories)
-     n_confusion = 10000
-
-     # Keep track of correct guesses in a confusion matrix
-     for i in range(n_confusion):
-         category, line, category_tensor, line_tensor = randomPair('all')
-         output = net(line_tensor)
-         guess, guess_i = categoryFromOutput(output.data)
-         category_i = all_categories.index(category)
-         confusion[category_i][guess_i] += 1
-
-     # normalize every row by its own sum; keepdim=True makes the division broadcast per row
-     confusion = confusion / confusion.sum(1, keepdim=True)
-     np.save('confusion.npy', confusion.numpy())
-
-     # Set up plot
-     fig = plt.figure(figsize=(10, 8))
-     ax = fig.add_subplot(111)
-     cax = ax.matshow(confusion.numpy())
-     fig.colorbar(cax)
-     # Set up axes
-     ax.set_xticklabels([''] + all_categories, rotation=90)
-     ax.set_yticklabels([''] + all_categories)
-     # Force label at every tick
-     ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
-     ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
-     # sphinx_gallery_thumbnail_number = 2
-     # save before showing, since the figure may be cleared once the window is closed
-     plt.savefig('ConfusionMatrix.svg')
-     plt.show()
-     plt.close()
-
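To make the row normalization of the confusion matrix concrete, here is a small dependency-free sketch. The 3x3 counts and the helper ``normalize_rows`` are made up for illustration; the script itself operates on an ``n_categories`` by ``n_categories`` PyTorch tensor.

```python
def normalize_rows(counts):
    """Divide every row by its own sum so that each non-empty row sums to 1."""
    normalized = []
    for row in counts:
        total = sum(row)
        # Leave all-zero rows untouched to avoid dividing by zero.
        normalized.append([v / total for v in row] if total else list(row))
    return normalized

# Hypothetical counts: rows are actual classes, columns are guessed classes.
counts = [
    [8, 1, 1],
    [2, 6, 2],
    [0, 0, 4],
]
normalized = normalize_rows(counts)
for row in normalized:
    print(row)
```

After this step, entry (i, j) is the fraction of samples from actual language i that were assigned to language j, so a perfect classifier would produce the identity matrix.
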
- Executing ``%run ./spiking_lstm_text.py`` in an IPython console with ``IF_TRAIN = 0`` produces output like the following.
-
- .. code-block:: shell
-
-     Testing...
-     Calculating testing accuracy...
-     Test accuracy: 0.512, Random guess: 0.056
-     Please input a surname to predict its language origin:> YU
-     > YU
-     (0.18) Scottish
-     (0.12) English
-     (0.11) Italian
-     Please input a surname to predict its language origin:> Yu
-     > Yu
-     (0.63) Chinese
-     (0.23) Korean
-     (0.07) Vietnamese
-     Please input a surname to predict its language origin:> Zou
-     > Zou
-     (1.00) Chinese
-     (0.00) Arabic
-     (0.00) Polish
-     Calculating confusion matrix...
-
- The following picture shows the confusion matrix. A brighter diagonal element means that the corresponding language is predicted more accurately and is thus confused less with other languages, as is the case for Arabic and Greek.
- However, some pairs of languages are prone to confusion, such as Korean and Chinese, or English and Scottish.
-
- .. image:: ../_static/tutorials/activation_based/\9_spikingLSTM_text/ConfusionMatrix.*
-     :width: 100%
|