#3 Upload files to ''

Open
1030514181 wants to merge 1 commit from 1030514181-patch-1 into master

+158 -0 PrototypicalNetwork.py

@@ -0,0 +1,158 @@
# encoding=utf-8
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import jieba

def createData():
    # Toy corpus: five positive and five negative sentences about cigarettes,
    # plus one tokenized test example per class.
    text_list_pos = ["卷烟很好抽", "卷烟题材很好", "卷烟很好看", "价格很实惠", "利群有品味"]
    text_list_neg = ["烟很垃圾", "卷烟是真的垃圾", "抽起来太难吃了", "又臭又长", "烟太让人失望了"]
    test_pos = ["卷烟", "很", "好"]
    test_neg = ["卷烟", "垃圾"]
    # Segment each sentence into words with jieba.
    words_pos = [[item for item in jieba.cut(text)] for text in text_list_pos]
    words_neg = [[item for item in jieba.cut(text)] for text in text_list_neg]
    words_all = []
    for item in words_pos:
        for key in item:
            words_all.append(key)
    for item in words_neg:
        for key in item:
            words_all.append(key)
    # sorted() keeps the vocabulary (and hence the one-hot columns)
    # deterministic across runs; the original list(set(...)) did not.
    vocab = sorted(set(words_all))
    word2idx = {w: c for c, w in enumerate(vocab)}
    idx_words_pos = [[word2idx[item] for item in text] for text in words_pos]
    idx_words_neg = [[word2idx[item] for item in text] for text in words_neg]
    idx_test_pos = [word2idx[item] for item in test_pos]
    idx_test_neg = [word2idx[item] for item in test_neg]
    return vocab, word2idx, idx_words_pos, idx_words_neg, idx_test_pos, idx_test_neg
def createOneHot(vocab, idx_words_pos, idx_words_neg, idx_test_pos, idx_test_neg):
    # Multi-hot bag-of-words features: one row per sentence, one column per word.
    input_dim = len(vocab)
    features_pos = torch.zeros(size=[len(idx_words_pos), input_dim])
    features_neg = torch.zeros(size=[len(idx_words_neg), input_dim])
    for i in range(len(idx_words_pos)):
        for j in idx_words_pos[i]:
            features_pos[i, j] = 1.0
    for i in range(len(idx_words_neg)):
        for j in idx_words_neg[i]:
            features_neg[i, j] = 1.0
    features = torch.cat([features_pos, features_neg], dim=0)
    # Class 0 = positive, class 1 = negative: forward() builds prototype i from
    # the i-th contiguous block of support rows, and positive rows come first,
    # so the labels must follow that order (the original had them inverted).
    labels = torch.LongTensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
    test_x_pos = torch.zeros(size=[1, input_dim])
    test_x_neg = torch.zeros(size=[1, input_dim])
    for item in idx_test_pos:
        test_x_pos[0, item] = 1.0
    for item in idx_test_neg:
        test_x_neg[0, item] = 1.0
    test_x = torch.cat([test_x_pos, test_x_neg], dim=0)
    test_labels = torch.LongTensor([0, 1])
    return features, labels, test_x, test_labels
def randomGenerate(features):
    # Build one 2-way 3-shot episode: 3 support and 2 query sentences per class.
    N = features.shape[0]
    half_n = N // 2
    support_input = torch.zeros(size=[6, features.shape[1]])
    query_input = torch.zeros(size=[4, features.shape[1]])
    positive_list = list(range(0, half_n))
    negative_list = list(range(half_n, N))
    support_list_pos = random.sample(positive_list, 3)
    support_list_neg = random.sample(negative_list, 3)
    query_list_pos = [item for item in positive_list if item not in support_list_pos]
    query_list_neg = [item for item in negative_list if item not in support_list_neg]
    index = 0
    for item in support_list_pos:
        support_input[index, :] = features[item, :]
        index += 1
    for item in support_list_neg:
        support_input[index, :] = features[item, :]
        index += 1
    index = 0
    for item in query_list_pos:
        query_input[index, :] = features[item, :]
        index += 1
    for item in query_list_neg:
        query_input[index, :] = features[item, :]
        index += 1
    # Positive queries come first, so with class 0 = positive the labels are
    # [0, 0, 1, 1] (the original [1, 1, 0, 0] was inverted).
    query_label = torch.LongTensor([0, 0, 1, 1])
    return support_input, query_input, query_label
class fewModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_class):
        super(fewModel, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_class = num_class
        # A single linear layer serves as the encoder.
        self.linear = nn.Linear(input_dim, hidden_dim)

    def embedding(self, features):
        return self.linear(features)

    def forward(self, support_input, query_input):
        support_embedding = self.embedding(support_input)
        query_embedding = self.embedding(query_input)
        support_size = support_embedding.shape[0]
        every_class_num = support_size // self.num_class
        # Each class prototype is the mean embedding of that class's support
        # rows; support rows are assumed to be grouped by class.
        class_meta_dict = {}
        for i in range(0, self.num_class):
            class_meta_dict[i] = torch.sum(
                support_embedding[i * every_class_num:(i + 1) * every_class_num, :], dim=0) / every_class_num
        class_meta_information = torch.zeros(size=[len(class_meta_dict), support_embedding.shape[1]])
        for key, item in class_meta_dict.items():
            class_meta_information[key, :] = class_meta_dict[key]
        N_query = query_embedding.shape[0]
        result = torch.zeros(size=[N_query, self.num_class])
        # Score each query by cosine similarity to every class prototype.
        for i in range(0, N_query):
            temp_value = query_embedding[i].repeat(self.num_class, 1)
            cosine_value = torch.cosine_similarity(class_meta_information, temp_value, dim=1)
            result[i] = cosine_value
        result = F.log_softmax(result, dim=1)
        return result
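
# A vectorized sketch of the prototype/similarity computation in forward()
# above (an alternative, not part of the original PR). It assumes support rows
# are grouped by class, as forward() already requires:
#
#   prototypes = support_embedding.view(self.num_class, -1, self.hidden_dim).mean(dim=1)
#   scores = F.cosine_similarity(query_embedding.unsqueeze(1), prototypes.unsqueeze(0), dim=2)
#   result = F.log_softmax(scores, dim=1)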
hidden_dim = 4
n_class = 2
lr = 0.01
epochs = 10
vocab, word2idx, idx_words_pos, idx_words_neg, idx_test_pos, idx_test_neg = createData()
features, labels, test_x, test_labels = createOneHot(vocab, idx_words_pos, idx_words_neg, idx_test_pos, idx_test_neg)
model = fewModel(features.shape[1], hidden_dim, n_class)
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)
def train(epoch, support_input, query_input, query_label):
    optimizer.zero_grad()
    output = model(support_input, query_input)
    loss = F.nll_loss(output, query_label)
    loss.backward()
    optimizer.step()
    # loss.item() extracts the Python float for formatting.
    print("Epoch: {:04d}".format(epoch), "loss: {:.4f}".format(loss.item()))
    print(output)
if __name__ == '__main__':
    # createData()/createOneHot() already ran at module level to size the
    # model, so re-running them here (as the original did) was redundant.
    for i in range(epochs):
        # Episodic training would instead sample a fresh support/query split:
        # support_input, query_input, query_label = randomGenerate(features)
        # train(i, support_input, query_input, query_label)
        train(i, features, features, labels)
    print(features)
    print(labels)
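    # test_x / test_labels are built above but never used. A minimal evaluation
    # sketch (an assumption, not part of the original PR): treat the full
    # training set as the support set and the two test sentences as queries.
    with torch.no_grad():
        test_output = model(features, test_x)
        pred = test_output.argmax(dim=1)
        print("test predictions:", pred.tolist(),
              "accuracy: {:.2f}".format((pred == test_labels).float().mean().item()))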
