已解决
PyG-GCN-Cora(在Cora数据集上应用GCN做节点分类)
来自网友在路上 160860提问 提问时间:2023-09-20 05:26:39阅读次数: 60
最佳答案 问答题库608位专家为你答疑解惑
文章目录
- model.py
- main.py
- 参数设置
- 注意事项
- 运行图
model.py
import torch.nn as nn
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
class gcn_cls(nn.Module):def __init__(self,in_dim,hid_dim,out_dim,dropout_size=0.5):super(gcn_cls,self).__init__()self.conv1 = GCNConv(in_dim,hid_dim)self.conv2 = GCNConv(hid_dim,hid_dim)self.fc = nn.Linear(hid_dim,out_dim)self.relu = nn.ReLU()self.dropout_size = dropout_sizedef forward(self,x,edge_index):x = self.conv1(x,edge_index)x = F.dropout(x,p=self.dropout_size,training=self.training)x = self.relu(x)x = self.conv2(x,edge_index)x = self.relu(x)x = self.fc(x)return x
main.py
import torch
import torch.nn as nn
from torch_geometric.datasets import Planetoid
from model import gcn_cls
import torch.optim as optim
dataset = Planetoid(root='./data/Cora', name='Cora')
print(dataset[0])
cora_data = dataset[0]epochs = 50
lr = 1e-3
weight_decay = 5e-3
momentum = 0.5
hidden_dim = 128
output_dim = 7net = gcn_cls(cora_data.x.shape[1],hidden_dim,output_dim)
optimizer = optim.AdamW(net.parameters(),lr=lr,weight_decay=weight_decay)
#optimizer = optim.SGD(net.parameters(),lr = lr,momentum=momentum)
criterion = nn.CrossEntropyLoss()
print("****************Begin Training****************")
net.train()
for epoch in range(epochs):out = net(cora_data.x,cora_data.edge_index)optimizer.zero_grad()loss_train = criterion(out[cora_data.train_mask],cora_data.y[cora_data.train_mask])loss_val = criterion(out[cora_data.val_mask],cora_data.y[cora_data.val_mask])loss_train.backward()print('epoch',epoch+1,'loss-train {:.2f}'.format(loss_train),'loss-val {:.2f}'.format(loss_val))optimizer.step()net.eval()
out = net(cora_data.x,cora_data.edge_index)
loss_test = criterion(out[cora_data.test_mask],cora_data.y[cora_data.test_mask])
_,pred = torch.max(out,dim=1)
pred_label = pred[cora_data.test_mask]
true_label = cora_data.y[cora_data.test_mask]
acc = sum(pred_label==true_label)/len(pred_label)
print("****************Begin Testing****************")
print('loss-test {:.2f}'.format(loss_test),'acc {:.2f}'.format(acc))
参数设置
epochs = 50
lr = 1e-3
weight_decay = 5e-3
momentum = 0.5
hidden_dim = 128
output_dim = 7
output_dim是输出维度,也就是有多少可能的类别。
注意事项
1.发现loss不下降:
建议改一改lr(学习率),我做的时候开始用的SGD,学习率设的0.01发现loss不下降,改成0.1后好了很多。如果用AdamW,0.001(1e-3)基本就够用了
运行图
查看全文
99%的人还看了
猜你感兴趣
版权申明
本文"PyG-GCN-Cora(在Cora数据集上应用GCN做节点分类)":http://eshow365.cn/6-9803-0.html 内容来自互联网,请自行判断内容的正确性。如有侵权请联系我们,立即删除!
- 上一篇: 【面试经典150 | 双指针】三数之和
- 下一篇: MySQL详解 四:MySQL的日志管理