class GCN_Unsupervised(nn.Module):
    """Two-layer GCN encoder followed by DMoN cluster pooling, trained without labels.

    Node features are encoded by two ``GCNConv`` layers. Then
    ``torch_geometric.nn.DMoNPooling`` soft-assigns the nodes to
    ``num_clusters`` clusters. It returns pooled per-cluster features together
    with DMoN's unsupervised auxiliary losses.

    Args:
        in_channels: Input node-feature dimension. When ``None`` (the default,
            matching the original behaviour), it is read from the module-level
            ``dataset.num_node_features``.
        hidden_channels: Output width of both GCN layers (default 16).
        num_clusters: Number of DMoN clusters ``k`` (default 5).
    """

    def __init__(self, in_channels=None, hidden_channels=16, num_clusters=5):
        super().__init__()
        # PyG is already a dependency of this file (GCNConv). The DMoN layer in
        # PyG is ``DMoNPooling``; there is no ``DmonPool`` class.
        from torch_geometric.nn import DMoNPooling

        if in_channels is None:
            # Backward-compatible default: size the input from the global dataset.
            in_channels = dataset.num_node_features
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        # DMoNPooling(channels, k): ``channels`` is the feature size fed to its
        # assignment MLP; ``k`` is the number of clusters (not a cluster *size*).
        self.pool = DMoNPooling(hidden_channels, num_clusters)

    def forward(self, data):
        """Encode ``data`` with the GCN layers and pool the nodes into clusters.

        Args:
            data: A PyG ``Data``/``Batch`` object with ``x`` and ``edge_index``.
                ``batch`` is optional; if it is absent, all nodes are treated
                as one graph.

        Returns:
            ``(x, loss)``. ``x`` is the pooled cluster-feature tensor returned by
            ``DMoNPooling``. ``loss`` is the sum of its spectral, orthogonality
            and cluster losses.
        """
        from torch_geometric.utils import to_dense_adj, to_dense_batch

        x, edge_index = data.x, data.edge_index
        batch = getattr(data, "batch", None)

        x = torch.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)

        # DMoNPooling is a dense layer. It needs padded node features, a dense
        # adjacency and a node mask, not a sparse edge_index.
        x, mask = to_dense_batch(x, batch)
        # Pin the adjacency size to the padded node count. Otherwise trailing
        # isolated nodes would make adj smaller than x.
        adj = to_dense_adj(edge_index, batch, max_num_nodes=x.size(1))

        _, x, _, spectral_loss, ortho_loss, cluster_loss = self.pool(x, adj, mask)
        # The cluster (collapse-regularization) loss alone only balances cluster
        # sizes. The spectral (modularity) term is what ties clusters to the graph.
        # Summing all three mirrors PyG's DMoN example; adjust the weights if needed.
        return x, spectral_loss + ortho_loss + cluster_loss
# Module-level model instance. Its constructor reads the global ``dataset`` to
# size the first GCN layer, so ``dataset`` must already be defined here.
model = GCN_Unsupervised()
# 카테고리 없음 ("Uncategorized"): leftover blog category label, not part of the program.