본문 바로가기
카테고리 없음

ㅎㅎ

by 까다로운오리 2024. 8. 8.

class GCN_Unsupervised(nn.Module):
    """Two ``GCNConv`` layers followed by a ``DmonPool`` clustering layer.

    Args:
        in_channels: Number of input features per node. When ``None`` (the
            default) the value is read from the module-level
            ``dataset.num_node_features``, so calling ``GCN_Unsupervised()``
            with no arguments behaves as before.
        hidden_channels: Output width of both ``GCNConv`` layers.
        cluster_size: Value forwarded to ``DmonPool(cluster_size=...)``.

    NOTE(review): ``DmonPool`` is not defined in this snippet. PyTorch
    Geometric's built-in DMoN layer is ``DMoNPooling(channels, k)``, which
    has a different constructor and call signature -- confirm which
    implementation is intended here.
    """

    def __init__(self, in_channels=None, hidden_channels=16, cluster_size=5):
        super().__init__()
        if in_channels is None:
            # Backward-compatible fallback: take the input width from the
            # global ``dataset`` only when the caller did not supply one.
            in_channels = dataset.num_node_features
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.pool = DmonPool(cluster_size=cluster_size)  # set the cluster size

    def forward(self, data):
        """Encode the graph and pool the node embeddings.

        Args:
            data: Object exposing ``x`` and ``edge_index`` attributes.

        Returns:
            A ``(x, cluster_loss)`` tuple: the first and third items of the
            four-item result returned by ``self.pool``.
        """
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = torch.relu(x)
        # No activation after the second convolution; its raw output is pooled.
        x = self.conv2(x, edge_index)
        # Apply DmonPool; the second and fourth outputs are not used.
        x, _, cluster_loss, _ = self.pool(x, edge_index)
        return x, cluster_loss

# Build the model with no arguments; the constructor then takes the input
# feature count from the module-level `dataset`, which must already exist.
model = GCN_Unsupervised()