ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

Task06:基于图神经网络的图表征学习方法

2021-07-05 21:03:50  阅读:360  来源: 互联网

标签:dim nn emb self 神经网络 128 表征 节点 Task06


Task06:基于图神经网络的图表征学习方法

本文参考datawhale开源学习资料

一、基于图同构网络(GIN)的图表征网络的实现

1. GINConv–图同构卷积层

提出图同构网络的论文是:How Powerful are Graph Neural Networks? ,可以参考【GNN】WL-test:GNN 的性能上界进行阅读。

图同构卷积层的数学定义如下:
x i ′ = h Θ ( ( 1 + ϵ ) ⋅ x i + ∑ j ∈ N ( i ) x j ) \mathbf{x}^{\prime}_i = h_{\mathbf{\Theta}} \left( (1 + \epsilon) \cdot \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \right) xi′​=hΘ​⎝⎛​(1+ϵ)⋅xi​+j∈N(i)∑​xj​⎠⎞​
PyG中已经实现了此模块,我们可以通过torch_geometric.nn.GINConv来使用PyG定义好的图同构卷积层,此外,torch_geometric.nn.GINEConv支持存在边属性的图。在这里我们自己**自定义一个支持边属性的GINConv模块,详细描述可见datawhale开源资料

import torch
from torch import nn
from torch_geometric.nn import MessagePassing
import torch.nn.functional as F
from ogb.graphproppred.mol_encoder import BondEncoder

### GIN convolution along the graph structure
class GINConv(MessagePassing):
    def __init__(self, emb_dim):
        '''
            emb_dim (int): node embedding dimensionality
        '''
        super(GINConv, self).__init__(aggr = "add")

        self.mlp = nn.Sequential(nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU(), nn.Linear(emb_dim, emb_dim))
        self.eps = nn.Parameter(torch.Tensor([0]))
        self.bond_encoder = BondEncoder(emb_dim = emb_dim)

    def forward(self, x, edge_index, edge_attr):
        edge_embedding = self.bond_encoder(edge_attr) # 先将类别型边属性转换为边表征
        out = self.mlp((1 + self.eps) *x + self.propagate(edge_index, x=x, edge_attr=edge_embedding))
        return out

    def message(self, x_j, edge_attr):
        return F.relu(x_j + edge_attr)
        
    def update(self, aggr_out):
        return aggr_out

GINConv模块主要强调了以下几点:

  • aggregate即消息聚合函数需要为sum,使用max或者mean在某些情况下会对某些结构判断失败,如下图。
  • h Θ h_{\mathbf{\Theta}} hΘ​在代码中定义为self.mlp,即需要为至少两层的前馈神经网络。self.mlp可以近似拟合任何函数,summlp的结合能保证为多重集(mutiset)上的单射函数。
  • ϵ \epsilon ϵ可以自己定义,也可以通过学习得到,原文中作者将 ϵ \epsilon ϵ设置为0,在大部分的数据集上都能取得很好的效果。
    在这里插入图片描述

2. 基于图同构网络的节点嵌入模块

import torch
from ogb.graphproppred.mol_encoder import AtomEncoder
import torch.nn.functional as F

# GNN to generate node embedding
class GINNodeEmbedding(torch.nn.Module):
    """
    Output:
        node representations
    """

    def __init__(self, num_layers, emb_dim, drop_ratio=0.5, JK="last", residual=False):
        """GIN Node Embedding Module"""

        super(GINNodeEmbedding, self).__init__()
        self.num_layers = num_layers
        self.drop_ratio = drop_ratio
        self.JK = JK
        # add residual connection or not
        self.residual = residual

        if self.num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.atom_encoder = AtomEncoder(emb_dim)

        # List of GNNs
        self.convs = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()

        for layer in range(num_layers):
            self.convs.append(GINConv(emb_dim))
            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))

    def forward(self, batched_data):
        x, edge_index, edge_attr = batched_data.x, batched_data.edge_index, batched_data.edge_attr

        # computing input node embedding
        h_list = [self.atom_encoder(x)]  # 先将类别型原子属性转化为原子表征
        for layer in range(self.num_layers):
            h = self.convs[layer](h_list[layer], edge_index, edge_attr)
            h = self.batch_norms[layer](h)
            if layer == self.num_layers - 1:
                # remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training=self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)

            if self.residual:
                h += h_list[layer]

            h_list.append(h)

        # Different implementations of Jk-concat
        if self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "sum":
            node_representation = 0
            for layer in range(self.num_layers + 1):
                node_representation += h_list[layer]

        return node_representation
    
# 定义网络并打印
model = GINNodeEmbedding(num_layers=2, emb_dim=128)
print(model)

我们可以看到定义的网络结构如下:

GINNodeEmbedding(
  (atom_encoder): AtomEncoder(
    (atom_embedding_list): ModuleList(
      (0): Embedding(119, 128)
      (1): Embedding(4, 128)
      (2): Embedding(12, 128)
      (3): Embedding(12, 128)
      (4): Embedding(10, 128)
      (5): Embedding(6, 128)
      (6): Embedding(6, 128)
      (7): Embedding(2, 128)
      (8): Embedding(2, 128)
    )
  )
  (convs): ModuleList(
    (0): GINConv(
      (mlp): Sequential(
        (0): Linear(in_features=128, out_features=128, bias=True)
        (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): Linear(in_features=128, out_features=128, bias=True)
      )
      (bond_encoder): BondEncoder(
        (bond_embedding_list): ModuleList(
          (0): Embedding(5, 128)
          (1): Embedding(6, 128)
          (2): Embedding(2, 128)
        )
      )
    )
    (1): GINConv(
      (mlp): Sequential(
        (0): Linear(in_features=128, out_features=128, bias=True)
        (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): Linear(in_features=128, out_features=128, bias=True)
      )
      (bond_encoder): BondEncoder(
        (bond_embedding_list): ModuleList(
          (0): Embedding(5, 128)
          (1): Embedding(6, 128)
          (2): Embedding(2, 128)
        )
      )
    )
  )
  (batch_norms): ModuleList(
    (0): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)

GINNodeEmbedding的参数说明:

3. 基于图同构网络的图表征模块

import torch
from torch import nn
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set

class GINGraphRepr(nn.Module):

    def __init__(self, num_tasks=1, num_layers=5, emb_dim=300, residual=False, drop_ratio=0, JK="last", graph_pooling="sum"):
        """GIN Graph Pooling Module
        Args:
            num_tasks (int, optional): number of labels to be predicted. Defaults to 1 (控制了图表征的维度,dimension of graph representation).
            num_layers (int, optional): number of GINConv layers. Defaults to 5.
            emb_dim (int, optional): dimension of node embedding. Defaults to 300.
            residual (bool, optional): adding residual connection or not. Defaults to False.
            drop_ratio (float, optional): dropout rate. Defaults to 0.
            JK (str, optional): 可选的值为"last"和"sum"。选"last",只取最后一层的结点的嵌入,选"sum"对各层的结点的嵌入求和。Defaults to "last".
            graph_pooling (str, optional): pooling method of node embedding. 可选的值为"sum","mean","max","attention"和"set2set"。 Defaults to "sum".

        Out:
            graph representation
        """
        super(GINGraphPooling, self).__init__()

        self.num_layers = num_layers
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks

        if self.num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.gnn_node = GINNodeEmbedding(num_layers, emb_dim, JK=JK, drop_ratio=drop_ratio, residual=residual)

        # Pooling function to generate whole-graph embeddings
        if graph_pooling == "sum":
            self.pool = global_add_pool
        elif graph_pooling == "mean":
            self.pool = global_mean_pool
        elif graph_pooling == "max":
            self.pool = global_max_pool
        elif graph_pooling == "attention":
            self.pool = GlobalAttention(gate_nn=nn.Sequential(
                nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU(), nn.Linear(emb_dim, 1)))
        elif graph_pooling == "set2set":
            self.pool = Set2Set(emb_dim, processing_steps=2)
        else:
            raise ValueError("Invalid graph pooling type.")

        if graph_pooling == "set2set":
            self.graph_pred_linear = nn.Linear(2*self.emb_dim, self.num_tasks)
        else:
            self.graph_pred_linear = nn.Linear(self.emb_dim, self.num_tasks)

    def forward(self, batched_data):
        h_node = self.gnn_node(batched_data)

        h_graph = self.pool(h_node, batched_data.batch)
        output = self.graph_pred_linear(h_graph)

        if self.training:
            return output
        else:
            # At inference time, relu is applied to output to ensure positivity
            # 因为预测目标的取值范围就在 (0, 50] 内
            return torch.clamp(output, min=0, max=50)

与节点表征和链接预测学习不同的是,图表征学习需要一个额外的图池化(Graph Pooling),或称为图读出(Graph Readout)的模块,以得到图的表征(Graph Representation)。图池化可以分为层级池化(Pooling layers)全局池化(Global pooling layers)

上述代码中的graph_pooling="sum"代表从节点表征中得出全局表示,应满足节点的置换不变性。可选参数对应torch_geometric中全局池化的方法如下:

方法操作相关模块相关文献
sum对节点表征求和torch_geometric.nn.glob.global_add_pool——
mean对节点表征求平均torch_geometric.nn.glob.global_mean_pool——
max取节点表征的最大值torch_geometric.nn.glob.global_max_pool——
attention基于Attention对节点表征加权求和torch_geometric.nn.glob.GlobalAttention“Gated Graph Sequence Neural Networks”
set2set基于Attention对节点表征加权求和torch_geometric.nn.glob.Set2Set“Order Matters: Sequence to sequence for sets”

在图同构网络(GIN)中,Readout函数如下:
h G = CONCAT ( READOUT ( { h v ( k ) ∣ v ∈ G } ) ∣ k = 0 , 1 , ⋯   , K ) h_{G} = \text{CONCAT}(\text{READOUT}\left(\{h_{v}^{(k)}|v\in G\}\right)|k=0,1,\cdots, K) hG​=CONCAT(READOUT({hv(k)​∣v∈G})∣k=0,1,⋯,K)
其中的Readout可为求和操作,观察到使用了JK策略,并且是使用concat函数将各层特征拼接起来。

二、理论分析

1. 图同构性测试

两个图是同构的,意思是两个图拥有一样的拓扑结构,也就是说,我们可以通过重新标记节点从一个图转换到另外一个图。Weisfeiler-Lehman 图的同构性测试算法,简称WL Test,是一种用于测试两个图是否同构的算法。

WL Test 的一维形式,类似于图神经网络中的邻接节点聚合。WL Test 1)迭代地聚合节点及其邻接节点的标签,然后 2)将聚合的标签散列(hash)成新标签,该过程形式化为下方的公示,
L u h ← hash ⁡ ( L u h − 1 + ∑ v ∈ N ( U ) L v h − 1 ) L^{h}_{u} \leftarrow \operatorname{hash}\left(L^{h-1}_{u} + \sum_{v \in \mathcal{N}(U)} L^{h-1}_{v}\right) Luh​←hash⎝⎛​Luh−1​+v∈N(U)∑​Lvh−1​⎠⎞​
在上方的公示中, L u h L^{h}_{u} Luh​表示节点 u u u的第 h h h次迭代的标签,第 0 0 0次迭代的标签为节点原始标签。

在迭代过程中,发现两个图之间的节点的标签不同时,就可以确定这两个图是非同构的。需要注意的是节点标签可能的取值只能是有限个数。

WL测试不能保证对所有图都有效,特别是对于具有高度对称性的图,如链式图、完全图、环图和星图,它会判断错误。

Weisfeiler-Lehman Graph Kernels 方法提出用WL子树核衡量图之间相似性。该方法使用WL Test不同迭代中的节点标签计数作为图的表征向量,它具有与WL Test相同的判别能力。直观地说,在WL Test的第 k k k次迭代中,一个节点的标签代表了以该节点为根的高度为 k k k的子树结构。

Weisfeiler-Leman Test 算法举例说明:给定两个图 G G G和 G ′ G^{\prime} G′,每个节点拥有标签(实际中,一些图没有节点标签,我们可以以节点的度作为标签)。
在这里插入图片描述

Weisfeiler-Leman Test 算法通过重复执行以下给节点打标签的过程来实现图是否同构的判断

  1. 聚合自身与邻接节点的标签得到一串字符串,自身标签与邻接节点的标签中间用,分隔,邻接节点的标签按升序排序。排序的原因在于要保证单射性,即保证输出的结果不因邻接节点的顺序改变而改变。

在这里插入图片描述

  1. 标签散列,即标签压缩,将较长的字符串映射到一个简短的标签。
    在这里插入图片描述

  2. 给节点重新打上标签。

在这里插入图片描述

每重复一次以上的过程,就完成一次节点自身标签与邻接节点标签的聚合。

当出现两个图相同节点标签的出现次数不一致时,即可判断两个图不相似。如果上述的步骤重复一定的次数后,没有发现有相同节点标签的出现次数不一致的情况,那么我们无法判断两个图是否同构。

当两个节点的 h h h层的标签一样时,表示分别以这两个节点为根节点的WL子树是一致的。WL子树与普通子树不同,WL子树包含重复的节点。下图展示了一棵以1节点为根节点高为2的WL子树。

在这里插入图片描述

2. 图相似性评估

此方法来自于Weisfeiler-Lehman Graph Kernels

WL Test 算法的一点局限性是,它只能判断两个图的相似性,无法衡量图之间的相似性。要衡量两个图的相似性,我们用WL Subtree Kernel方法。该方法的思想是用WL Test算法得到节点的多层的标签,然后我们可以分别统计图中各类标签出现的次数,存于一个向量,这个向量可以作为图的表征。两个图的这样的向量的内积,即可作为这两个图的相似性的估计,内积越大表示相似性越高。

在这里插入图片描述

3. 图同构网络模型的构建

能实现判断图同构性的图神经网络需要满足,只在两个节点自身标签一样且它们的邻接节点一样时,图神经网络将这两个节点映射到相同的表征,即映射是单射性的。可重复集合(Multisets)指的是元素可重复的集合,元素在集合中没有顺序关系。 **一个节点的所有邻接节点是一个可重复集合,一个节点可以有重复的邻接节点,邻接节点没有顺序关系。**因此GIN模型中生成节点表征的方法遵循WL Test算法更新节点标签的过程。

在生成节点的表征后仍需要执行图池化(或称为图读出)操作得到图表征,最简单的图读出操作是做求和。由于每一层的节点表征都可能是重要的,因此在图同构网络中,不同层的节点表征在求和后被拼接,其数学定义如下,
h G = CONCAT ( READOUT ( { h v ( k ) ∣ v ∈ G } ) ∣ k = 0 , 1 , ⋯   , K ) h_{G} = \text{CONCAT}(\text{READOUT}\left(\{h_{v}^{(k)}|v\in G\}\right)|k=0,1,\cdots, K) hG​=CONCAT(READOUT({hv(k)​∣v∈G})∣k=0,1,⋯,K)
采用拼接而不是相加的原因在于不同层节点的表征属于不同的特征空间。未做严格的证明,这样得到的图的表示与WL Subtree Kernel得到的图的表征是等价的。

三、作业

  • 请画出下方图片中的6号、3号和5号节点的从1层到3层到WL子树。

在这里插入图片描述

答:

在这里插入图片描述

标签:dim,nn,emb,self,神经网络,128,表征,节点,Task06
来源: https://blog.csdn.net/m0_46306014/article/details/118498341

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有