PET-AI解读 | rs-fRMI的GNN和TCN建模(模型构建细节)

  • 相关论文:A deep graph neural network architecture for modelling spatio-temporal dynamics in resting-state functional MRI data
  • 相关repo:github.com/tjiagoM/spa…
  • 笔记人:陈亦新

主函数中生成了这样的模型:

model = SpatioTemporalModel(run_cfg=run_cfg,
                                encoding_model=None
                                ).to(run_cfg['device_run'])

这个SpatioTemporalModel十分的长,和以前解读工程一样,咱们只看forward函数就行,下面片段中的注释为我的理解:

class SpatioTemporalModel(nn.Module):
    def forward(self, data):
        # 这儿的三个数据,和咱们在上一末节解说的共同
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        if self.multimodal_size > 0:
            xn, x = x[:, :self.multimodal_size], x[:, self.multimodal_size:]
            xn = self.multimodal_lin(xn)
            xn = self.activation(xn)
            xn = self.multimodal_batch(xn)
            xn = F.dropout(xn, p=self.dropout, training=self.training)
        # Processing temporal part
        if self.conv_strategy != ConvStrategy.NONE:
            # 这儿似乎是吧LSTM也理解为Conv了
            if self.conv_strategy == ConvStrategy.LSTM:
                # 采用LSTM作为特征提取的办法
                x = x.view(-1, self.num_time_length, 1)
                # 能够见下面的LSTM-弥补1,便是用0初始化LSTM的隐含特征和cell state
                h0, c0 = self.init_lstm_hidden(x)
                # 可见下面LSTM-弥补2,一个LSTM模块
                x, (_, _) = self.temporal_conv(x, (h0, c0))
                x = x.contiguous()
            else:
                # 不是LSTM,那么便是卷积策略了。这儿卷积策略包含了一般的1D卷积,也包含了TCN的1D卷积模型。可见下方CNN-弥补1和TCN-弥补1
                x = x.view(-1, 1, self.num_time_length)
                x = self.temporal_conv(x)
            # Concatenating for the final embedding per node
            # 这个变量self.size_before_lin_temporal的数值,卷积通道x时刻序列长度。这时分卷积通道数现已扩大了8倍,时刻序列长度现已下采样了4次,变成本来的16分之1了。
            x = x.view(x.size()[0], self.size_before_lin_temporal)
            # 是一个全衔接层,也可能从_get_lin_temporal函数中得到的组件,详情能够看到下面的办法_get_lin_temporal
            x = self.lin_temporal(x)
            x = self.activation(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        elif self.encoding_strategy == EncodingStrategy.STATS:
        # 全衔接层self.stats_lin+1D BN层
            x = self.stats_lin(x)
            x = self.activation(x)
            x = self.stats_batch(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        elif self.encoding_strategy == EncodingStrategy.VAE3layers:
        # 这个也简单,便是VAE自编码器来做的特征提取
            mu, logvar = self.encoder_model.encode(x)
            x = self.encoder_model.reparameterize(mu, logvar)
        elif self.encoding_strategy == EncodingStrategy.AE3layers:
        # 和上面类似,是autoENcoder的
            x = self.encoder_model.encode(x)
        if self.multimodal_size > 0:
            x = torch.cat((xn, x), dim=1)
        # 到这一步的时分,咱们的x是现已从ts当中提取好的特征。
        # 图网络用了两个经典中的经典,GAT和GCN。GCN我之前有一篇ISBI的论文用的便是这个,后来就没再看过了。嘎嘎
        if self.sweep_type in [SweepType.GAT, SweepType.GCN]:
        # 总归,图网络的特征提取,其实和transformer的attention map十分类似。这儿在微观讲述模型结构的时分,暂时先不细讲,之后在仔细的考虑TCN和GNN的代码完成细节。
            if self.edge_weights:
                # 这个带上edge-weights的概念,也便是会输入两个节点之间的衔接的强弱。
                x = self.gnn_conv1(x, edge_index, edge_weight=edge_attr.view(-1))
            else:
                # 没有edgeweights的概念的,则是,只是告诉模型这两个节点有衔接有关系,可是并不会进一步的去诉说强弱
                x = self.gnn_conv1(x, edge_index)
            x = self.activation(x)
            x = F.dropout(x, training=self.training)
            # 看来这儿的图网络,也是一个十分浅层的,只有1层或许2层的网络。
            if self.num_gnn_layers == 2:
                if self.edge_weights:
                    x = self.gnn_conv2(x, edge_index, edge_weight=edge_attr.view(-1))
                else:
                    x = self.gnn_conv2(x, edge_index)
                x = self.activation(x)
                x = F.dropout(x, training=self.training)
        # 此外,作者还考虑了叫做PNANodeModel的特征提取器
        elif self.sweep_type == SweepType.META_NODE:
            x = self.meta_layer(x, edge_index, edge_attr)
        # 此外,作者还考虑了叫做MetaLayer的特征提取器
        elif self.sweep_type == SweepType.META_EDGE_NODE:
            x, edge_attr, _ = self.meta_layer(x, edge_index, edge_attr)
        # 这儿便是和上一章节解说的graph pool的方法,有均匀,相加和DiffPool
        if self.pooling == PoolingStrategy.MEAN:
            x = global_mean_pool(x, data.batch)
        elif self.pooling == PoolingStrategy.ADD:
            x = global_add_pool(x, data.batch)
        elif self.pooling in [PoolingStrategy.DIFFPOOL, PoolingStrategy.DP_MAX, PoolingStrategy.DP_ADD, PoolingStrategy.DP_MEAN, PoolingStrategy.DP_IMPROVED]:
        # 咱们还记得上一章遗留了一个问题,便是DiffPool只能处理稠密邻接矩阵,而咱们的是稀少的。所以转化的方法在这儿,可见下面的to_dense_ad部分
            adj_tmp = pyg_utils.to_dense_adj(edge_index, data.batch, edge_attr=edge_attr)
            if edge_attr is not None: # Because edge_attr only has 1 feature per edge
                adj_tmp = adj_tmp[:, :, :, 0]
            x_tmp, batch_mask = pyg_utils.to_dense_batch(x, data.batch)
            # self.diff_pool便是DiffPool这个组件,下一末节继续细讲
            x, link_loss, ent_loss = self.diff_pool(x_tmp, adj_tmp, batch_mask)
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = self.activation(self.pre_final_linear(x))
        elif self.pooling == PoolingStrategy.CONCAT:
            x, _ = to_dense_batch(x, data.batch)
            x = x.view(-1, self.NODE_EMBED_SIZE * self.num_nodes)
            x = self.activation(self.pre_final_linear(x))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.final_linear(x)
        if self.final_sigmoid:
            return torch.sigmoid(x) if self.pooling not in [PoolingStrategy.DIFFPOOL, PoolingStrategy.DP_MAX, PoolingStrategy.DP_ADD, PoolingStrategy.DP_MEAN, PoolingStrategy.DP_IMPROVED] else (
                torch.sigmoid(x), link_loss, ent_loss)
        else:
            return x if self.pooling not in [PoolingStrategy.DIFFPOOL, PoolingStrategy.DP_MAX, PoolingStrategy.DP_ADD, PoolingStrategy.DP_MEAN, PoolingStrategy.DP_IMPROVED] else (x, link_loss, ent_loss)

关于上述代码段的弥补扩展:

  • LSTM-弥补1

            def init_lstm_hidden(x):
                h0 = torch.zeros(run_cfg['tcn_depth'], x.size(0), run_cfg['tcn_hidden_units'])
                c0 = torch.zeros(run_cfg['tcn_depth'], x.size(0), run_cfg['tcn_hidden_units'])
                return [t.to(x.device) for t in (h0, c0)]
  • LSTM-弥补2
self.temporal_conv = nn.LSTM(input_size=1,
                                         hidden_size=run_cfg['tcn_hidden_units'],
                                         num_layers=run_cfg['tcn_depth'],
                                         dropout=dropout_perc,
                                         batch_first=True)
  • CNN-弥补1
stride = 2
            padding = 3
            self.size_before_lin_temporal = self.channels_conv * 8 * self.final_feature_size
            self.lin_temporal = nn.Linear(self.size_before_lin_temporal, self.NODE_EMBED_SIZE - self.multimodal_size)
            self.conv1d_1 = nn.Conv1d(1, self.channels_conv, 7, padding=padding, stride=stride)
            self.conv1d_2 = nn.Conv1d(self.channels_conv, self.channels_conv * 2, 7, padding=padding, stride=stride)
            self.conv1d_3 = nn.Conv1d(self.channels_conv * 2, self.channels_conv * 4, 7, padding=padding, stride=stride)
            self.conv1d_4 = nn.Conv1d(self.channels_conv * 4, self.channels_conv * 8, 7, padding=padding, stride=stride)
            self.batch1 = BatchNorm1d(self.channels_conv)
            self.batch2 = BatchNorm1d(self.channels_conv * 2)
            self.batch3 = BatchNorm1d(self.channels_conv * 4)
            self.batch4 = BatchNorm1d(self.channels_conv * 8)
            self.temporal_conv = nn.Sequential(self.conv1d_1, self.activation, self.batch1, nn.Dropout(dropout_perc),
                                               self.conv1d_2, self.activation, self.batch2, nn.Dropout(dropout_perc),
                                               self.conv1d_3, self.activation, self.batch3, nn.Dropout(dropout_perc),
                                               self.conv1d_4, self.activation, self.batch4, nn.Dropout(dropout_perc))
            self.init_weights()
  • TCN-弥补1
#self.size_before_lin_temporal = self.channels_conv * 8 * self.final_feature_size
            #self.lin_temporal = nn.Linear(self.size_before_lin_temporal, self.NODE_EMBED_SIZE - self.multimodal_size)
            if run_cfg['tcn_hidden_units'] == 8:
                self.size_before_lin_temporal = self.channels_conv * (2 ** (run_cfg['tcn_depth'] - 1)) * self.num_time_length
            else:
                self.size_before_lin_temporal = run_cfg['tcn_hidden_units'] * self.num_time_length
            self.lin_temporal = self._get_lin_temporal(run_cfg)
            tcn_layers = []
            for i in range(run_cfg['tcn_depth']):
                if run_cfg['tcn_hidden_units'] == 8:
                    tcn_layers.append(self.channels_conv * (2 ** i) )
                else:
                    tcn_layers.append(run_cfg['tcn_hidden_units'])
            self.temporal_conv = TemporalConvNet(1,
                                                 tcn_layers,
                                                 kernel_size=run_cfg['tcn_kernel'],
                                                 dropout=self.dropout,
                                                 norm_strategy=run_cfg['tcn_norm_strategy'])
  • _get_lin_temporal
def _get_lin_temporal(self, run_cfg):
        if run_cfg['tcn_final_transform_layers'] == 1:
            lin_temporal = nn.Linear(self.size_before_lin_temporal,
                                          self.NODE_EMBED_SIZE - self.multimodal_size)
        elif run_cfg['tcn_final_transform_layers'] == 2:
            lin_temporal = nn.Sequential(
                nn.Linear(self.size_before_lin_temporal, int(self.size_before_lin_temporal / 2)),
                self.activation, nn.Dropout(self.dropout),
                nn.Linear(int(self.size_before_lin_temporal / 2), self.NODE_EMBED_SIZE - self.multimodal_size))
        elif run_cfg['tcn_final_transform_layers'] == 3:
            lin_temporal = nn.Sequential(
                nn.Linear(self.size_before_lin_temporal, int(self.size_before_lin_temporal / 2)),
                self.activation, nn.Dropout(self.dropout),
                nn.Linear(int(self.size_before_lin_temporal / 2), int(self.size_before_lin_temporal / 3)),
                self.activation, nn.Dropout(self.dropout),
                nn.Linear(int(self.size_before_lin_temporal / 3), self.NODE_EMBED_SIZE - self.multimodal_size))
        return lin_temporal
  • to_dense_adj
import torch_geometric.utils as pyg_utils
pyg_utils.to_dense_adj

这个办法的目的是:Converts batched sparse adjacency matrices given by edge indices and edge attributes to a single dense batched adjacency matrix。

官方文档的介绍地址在:torch_geometric.utils.to_dense_adj — pytorch_geometric documentation (pytorch-geometric.readthedocs.io)

综上所述,便是时刻序列在这个模型当中经过的全部进程。先是对时刻序列进行编码,也便是抽取特征。抽取之后,挑选合适的图网络再此进行特征提取。最终使用DiffPool进行特征整合。