import argparse
import os
import time
import zipfile

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import average_precision_score
from torch_geometric.loader import NeighborLoader
from torch_geometric.logging import log
from torch_geometric.nn import RGCNConv

class RGCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, n_layers,
                 num_relations, n_bases, dropout):
        super().__init__()
        self.dropout = dropout
        self.convs = torch.nn.ModuleList()
        self.convs.append(RGCNConv(in_channels, hidden_channels, num_relations, num_bases=n_bases))
        for _ in range(n_layers - 2):
            self.convs.append(RGCNConv(hidden_channels, hidden_channels, num_relations, num_bases=n_bases))
        self.convs.append(RGCNConv(hidden_channels, out_channels, num_relations, num_bases=n_bases))

    def forward(self, x, edge_index, edge_type, perturb=None):
        # FLAG injects an adversarial perturbation into the input features.
        if perturb is not None:
            x = x + perturb
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index, edge_type)
            if i < len(self.convs) - 1:
                x = x.relu_()
                # The same dropout rate is applied after every hidden layer.
                x = F.dropout(x, p=self.dropout, training=self.training)
        return x

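# A minimal usage sketch (shapes and sizes here are illustrative assumptions,
# not taken from the dataset): RGCNConv expects node features of shape
# [num_nodes, in_channels], a homogeneous edge_index of shape [2, num_edges],
# and an integer relation-type vector edge_type of shape [num_edges].
def _demo_rgcn_forward():
    x = torch.randn(10, 256)                    # 10 nodes, 256-dim features
    edge_index = torch.randint(0, 10, (2, 40))  # 40 random edges
    edge_type = torch.randint(0, 4, (40,))      # 4 relation types
    demo = RGCN(in_channels=256, hidden_channels=64, out_channels=2,
                n_layers=3, num_relations=4, n_bases=8, dropout=0.5)
    return demo(x, edge_index, edge_type)       # logits of shape [10, 2]
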
# Path of the zip archive and the directory to extract it into.
zip_path = '/tmp/dataset/pyg_data.zip'
save_path = 'train'

print('Extracting...')
with zipfile.ZipFile(zip_path) as archive:
    archive.extractall(save_path)
print('Extraction finished.')

class FLAG:
    """FLAG adversarial feature augmentation: run m gradient-ascent steps on an
    input perturbation while accumulating the training loss averaged over m."""

    def __init__(self, num_nodes, in_channels, loss_func, model, optimizer, y,
                 m=3, step_size=1e-3):
        self.num_nodes = num_nodes
        self.in_channels = in_channels
        self.loss_func = loss_func
        self.model = model
        self.optimizer = optimizer
        self.y = y
        self.m = m
        self.step_size = step_size

    def __call__(self, forward):
        self.model.train()
        self.optimizer.zero_grad()
        device = self.y.device
        # Start from a uniform random perturbation of the node features.
        perturb = torch.FloatTensor(self.num_nodes, self.in_channels) \
            .uniform_(-self.step_size, self.step_size).to(device)
        perturb.requires_grad_()
        out = forward(perturb)
        loss = self.loss_func(out, self.y)
        loss /= self.m
        for _ in range(self.m - 1):
            loss.backward()
            # Ascent step on the perturbation: move along the gradient sign.
            perturb_data = perturb.detach() + self.step_size * torch.sign(perturb.grad.detach())
            perturb.data = perturb_data.data
            perturb.grad[:] = 0
            out = forward(perturb)
            loss = self.loss_func(out, self.y)
            loss /= self.m
        loss.backward()
        self.optimizer.step()
        return loss, out

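# Hedged sketch of the call pattern used in train() below: the model's forward
# pass is wrapped in a closure over everything except the perturbation, and
# FLAG drives the m ascent/descent steps. All arguments are placeholders.
def _demo_flag_step(model, optimizer, x, edge_index, edge_type, y):
    flag = FLAG(x.size(0), x.size(1), nn.CrossEntropyLoss(), model, optimizer, y)
    forward = lambda perturb: model(x, edge_index, edge_type, perturb)
    loss, out = flag(forward)
    return loss, out
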
parser = argparse.ArgumentParser(description='RGCN_pyg')
parser.add_argument("--dropout", type=float, default=0.5,
                    help="dropout probability")
parser.add_argument('--dataset', type=str, default='train/pyg_data/pyg_data.pt')
parser.add_argument('--labeled-class', type=str, default='item')
parser.add_argument("--batch-size", type=int, default=256,
                    help="Mini-batch size. If -1, use full graph training.")
parser.add_argument("--fanout", type=int, default=100,
                    help="Fan-out of neighbor sampling.")
parser.add_argument("--n-layers", type=int, default=3,
                    help="number of propagation rounds")
parser.add_argument("--h-dim", type=int, default=64,
                    help="number of hidden units")
parser.add_argument("--in-dim", type=int, default=256,
                    help="number of input features")
parser.add_argument("--n-bases", type=int, default=8,
                    help="number of basis weight matrices for RGCNConv")
parser.add_argument("--early_stopping", type=int, default=6)
parser.add_argument("--n-epoch", type=int, default=100)
parser.add_argument("--test-file", type=str, default="/tmp/dataset/pyg_data/icdm2022_session1_test_ids.csv")
parser.add_argument("--json-file", type=str, default="/tmp/output/pyg_pred.json")
parser.add_argument("--lr", type=float, default=0.001)
parser.add_argument("--l2norm", type=float, default=5e-4, help="l2 norm coef")
parser.add_argument("--model-path", type=str, default='/tmp/output/session1_rgcn_pyg.pt',
                    help='path for saving the model')
# Note: type=bool treats any non-empty string as True; only the default matters
# here since parse_args([]) ignores the command line.
parser.add_argument("--flag", type=bool, default=True, help='use FLAG augmentation')
args = parser.parse_args([])  # notebook-style execution: use the defaults

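# Because parse_args([]) is used, command-line flags are ignored; defaults can
# instead be overridden programmatically, e.g. (illustrative values):
#   args = parser.parse_args(['--h-dim', '128', '--batch-size', '512'])
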
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
hgraph = torch.load(args.dataset)
labeled_class = args.labeled_class
# Train/val seed-node indices; pop() removes them from the node storage so
# they are not treated as per-node attributes downstream.
train_idx = hgraph[labeled_class].pop('train_idx')
val_idx = hgraph[labeled_class].pop('val_idx')

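# Hedged sketch of the HeteroData layout this script assumes (a tiny synthetic
# stand-in, not the real ICDM 2022 graph; node/edge type names are made up):
# every node type stores features .x, and the labeled class additionally
# stores labels plus train/val index tensors.
def _demo_hetero_layout():
    from torch_geometric.data import HeteroData
    data = HeteroData()
    data['item'].x = torch.randn(8, 256)
    data['item'].y = torch.randint(0, 2, (8,))
    data['item'].train_idx = torch.arange(6)
    data['item'].val_idx = torch.arange(6, 8)
    data['user'].x = torch.randn(4, 256)
    data['user', 'clicks', 'item'].edge_index = torch.stack(
        [torch.randint(0, 4, (16,)), torch.randint(0, 8, (16,))])
    return data
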
@torch.no_grad()
def val(epoch, model, val_loader, labeled_class, device):
    model.eval()
    total_loss = total_correct = total_examples = 0
    y_pred = []
    y_true = []
    for batch in val_loader:
        batch_size = batch[labeled_class].batch_size
        y = batch[labeled_class].y[:batch_size].to(device)
        # After to_homogeneous(), node features are concatenated in node_types
        # order, so the labeled class's seed nodes start at this offset.
        start = 0
        for ntype in batch.node_types:
            if ntype == labeled_class:
                break
            start += batch[ntype].num_nodes

        batch = batch.to_homogeneous()

        y_hat = model(batch.x.to(device), batch.edge_index.to(device),
                      batch.edge_type.to(device))[start:start + batch_size]
        loss = F.cross_entropy(y_hat, y)
        y_pred.append(F.softmax(y_hat, dim=1)[:, 1].cpu())
        y_true.append(y.cpu())
        total_loss += float(loss) * batch_size
        total_correct += int((y_hat.argmax(dim=-1) == y).sum())
        total_examples += batch_size

    ap_score = average_precision_score(torch.hstack(y_true).numpy(),
                                       torch.hstack(y_pred).numpy())

    return total_loss / total_examples, total_correct / total_examples, ap_score

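# Quick hand-checked toy example of the metric used above and below:
# average_precision_score expects positive-class probabilities, not argmax
# labels, which is why the softmax column [:, 1] is collected.
#   average_precision_score([1, 0, 1], [0.8, 0.3, 0.1])  # = (1 + 2/3) / 2 ~ 0.83
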
def train(epoch, model, train_loader, optimizer, labeled_class, device):
    model.train()
    total_loss = total_correct = total_examples = 0
    y_preds = []
    y_trues = []
    for i, batch in enumerate(train_loader):
        optimizer.zero_grad()
        # Offset of the labeled class's seed nodes in the homogeneous batch.
        start = 0
        for ntype in batch.node_types:
            if ntype == labeled_class:
                break
            start += batch[ntype].num_nodes
        batch_size = batch[labeled_class].batch_size
        y = batch[labeled_class].y[:batch_size].to(device)
        batch = batch.to_homogeneous()

        if args.flag:
            # FLAG re-runs the forward pass m times with an adversarial
            # perturbation on the input features (see the FLAG class above).
            flag = FLAG(batch.x.shape[0], args.in_dim, nn.CrossEntropyLoss(),
                        model, optimizer, y)
            forward = lambda perturb: model(batch.x.to(device), batch.edge_index.to(device),
                                            batch.edge_type.to(device), perturb)[start:start + batch_size]
            loss, y_hat = flag(forward)
            loss = loss.item()
        else:
            y_hat = model(batch.x.to(device), batch.edge_index.to(device),
                          batch.edge_type.to(device))[start:start + batch_size]
            loss = F.cross_entropy(y_hat, y)
            loss.backward()
            optimizer.step()

        y_preds.append(F.softmax(y_hat, dim=1)[:, 1].detach().cpu())
        y_trues.append(y.cpu())
        total_loss += float(loss) * batch_size
        total_correct += int((y_hat.argmax(dim=-1) == y).sum())
        total_examples += batch_size

        # Per-batch metrics for periodic progress logging.
        if i % 20 == 0:
            y_pred = F.softmax(y_hat, dim=1)[:, 1].detach().cpu()
            y_true = y.detach().cpu()
            train_ap = average_precision_score(y_true.numpy(), y_pred.numpy())
            train_acc = torch.sum(y_hat.argmax(dim=1) == y).item() / len(y)
            print("Epoch {:03d} | Batch {:03d} | Train AP: {:.4f} | Train Acc: {:.4f}".
                  format(epoch, i, train_ap, train_acc))
            log(Epoch=epoch, Batch=i, Train=train_ap)

    ap_score = average_precision_score(torch.hstack(y_trues).numpy(),
                                       torch.hstack(y_preds).numpy())

    return total_loss / total_examples, total_correct / total_examples, ap_score

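# Hedged sketch of the NeighborLoader batch anatomy that train()/val() rely on
# (standard PyG behavior): the seed nodes of the input node type come first in
# the batch, so the first `batch_size` rows of the labeled class are exactly
# the mini-batch targets. Fan-out values here are illustrative.
def _demo_loader_batch():
    loader = NeighborLoader(hgraph, input_nodes=(labeled_class, train_idx),
                            num_neighbors=[10, 10], batch_size=4)
    batch = next(iter(loader))
    seed_labels = batch[labeled_class].y[:batch[labeled_class].batch_size]
    return seed_labels
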
def main(args):
    # Mini-batch loaders with a per-layer neighbor fan-out.
    train_loader = NeighborLoader(hgraph, input_nodes=(labeled_class, train_idx),
                                  num_neighbors=[args.fanout] * args.n_layers,
                                  shuffle=True, batch_size=args.batch_size)
    val_loader = NeighborLoader(hgraph, input_nodes=(labeled_class, val_idx),
                                num_neighbors=[args.fanout] * args.n_layers,
                                shuffle=False, batch_size=args.batch_size)
    num_relations = len(hgraph.edge_types)
    model = RGCN(in_channels=args.in_dim, hidden_channels=args.h_dim, out_channels=2,
                 n_layers=args.n_layers, num_relations=num_relations,
                 n_bases=args.n_bases, dropout=args.dropout).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.l2norm)

    # Training loop with early stopping on validation AP.
    print("start training...")
    os.makedirs(os.path.dirname(args.model_path), exist_ok=True)
    val_ap_list = []
    best_score = 0
    es = 0  # early-stopping counter
    for epoch in range(1, args.n_epoch + 1):
        # Training
        t0 = time.time()
        train_loss, train_acc, train_ap = train(epoch, model, train_loader, optimizer, labeled_class, device)
        print(f'Train: Epoch {epoch:02d}, Loss: {train_loss:.4f},'
              f' Acc: {train_acc:.4f}, AP_Score: {train_ap:.4f}, Time: {time.time() - t0:.4f}')
        # Validation
        t0 = time.time()
        val_loss, val_acc, val_ap = val(epoch, model, val_loader, labeled_class, device)
        print(f'Val: Epoch: {epoch:02d}, Loss: {val_loss:.4f}, Acc: {val_acc:.4f}, '
              f'AP_Score: {val_ap:.4f}, Time: {time.time() - t0:.4f}')
        log(Epoch=epoch, Train=train_ap, Val=val_ap)
        if val_ap > best_score:
            best_score = val_ap
            es = 0
            # Checkpoint the best model so far (args.model_path was otherwise unused).
            torch.save(model.state_dict(), args.model_path)
        else:
            es += 1
            print("Counter {} of {}".format(es, args.early_stopping))
            if es > args.early_stopping:
                print("Early stopping with best_valid ap: ", best_score)
                break
        val_ap_list.append(float(val_ap))
    print(f'Average validation AP over {len(val_ap_list)} epochs: {np.average(val_ap_list):.4f}')


if __name__ == '__main__':
    main(args)