graph attention network(ICLR2018)官方代码详解(tensorflow)-稀疏矩阵版
时间:2022-07-24
本文章向大家介绍graph attention network(ICLR2018)官方代码详解(tensorflow)-稀疏矩阵版,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
论文地址:https://arxiv.org/abs/1710.10903
代码地址: https://github.com/Diego999/pyGAT
之前非稀疏矩阵版的解读:https://www.cnblogs.com/xiximayou/p/13622283.html
我们知道图的邻接矩阵可能是稀疏的,将整个图加载到内存中是十分耗费资源的,因此对邻接矩阵进行存储和计算是很有必要的。
我们已经讲解了图注意力网络的非稀疏矩阵版本,再来弄清其稀疏矩阵版本就轻松了,接下来我们将来看不同之处。
主运行代码在:execute_cora_sparse.py中
同样的,先加载数据:
adj, features, y_train, y_val, y_test, train_mask, val_mask, test_mask = process.load_data(dataset)
其中adj是coo_matrix类型,features是lil_matrix类型。
对于features,我们最终还是:
def preprocess_features(features):
"""Row-normalize feature matrix and convert to tuple representation"""
rowsum = np.array(features.sum(1))
r_inv = np.power(rowsum, -1).flatten()
r_inv[np.isinf(r_inv)] = 0.
r_mat_inv = sp.diags(r_inv)
features = r_mat_inv.dot(features)
return features.todense(), sparse_to_tuple(features)
将其:
features, spars = process.preprocess_features(features)
转换为原始矩阵。
对于biases:
if sparse:
biases = process.preprocess_adj_bias(adj)
else:
adj = adj.todense()
adj = adj[np.newaxis]
biases = process.adj_to_bias(adj, [nb_nodes], nhood=1)
如果是稀疏格式的,就调用biases = process.preprocess_adj_bias(adj):
def preprocess_adj_bias(adj):
num_nodes = adj.shape[0] #2708
adj = adj + sp.eye(num_nodes) # self-loop 给对角上+1
adj[adj > 0.0] = 1.0 #大于0的值置为1
if not sp.isspmatrix_coo(adj):
adj = adj.tocoo()
adj = adj.astype(np.float32) #类型转换
indices = np.vstack((adj.col, adj.row)).transpose() # This is where I made a mistake, I used (adj.row, adj.col) instead
# return tf.SparseTensor(indices=indices, values=adj.data, dense_shape=adj.shape)
return indices, adj.data, adj.shape
这里看两个例子:
我们可以通过indices,data,shape来构造一个coo_matrix。
在定义计算图中的占位符时:
if sparse:
#bias_idx = tf.placeholder(tf.int64)
#bias_val = tf.placeholder(tf.float32)
#bias_shape = tf.placeholder(tf.int64)
bias_in = tf.sparse_placeholder(dtype=tf.float32)
else:
bias_in = tf.placeholder(dtype=tf.float32, shape=(batch_size, nb_nodes, nb_nodes))
使用bias_in = tf.sparse_placeholder(dtype=tf.float32)。
再接着就是模型中了,在utils文件夹下的layers.py中:
# Experimental sparse attention head (for running on datasets such as Pubmed)
# N.B. Because of limitations of current TF implementation, will work _only_ if batch_size = 1!
def sp_attn_head(seq, out_sz, adj_mat, activation, nb_nodes, in_drop=0.0, coef_drop=0.0, residual=False):
with tf.name_scope('sp_attn'):
if in_drop != 0.0:
seq = tf.nn.dropout(seq, 1.0 - in_drop)
seq_fts = tf.layers.conv1d(seq, out_sz, 1, use_bias=False)
# simplest self-attention possible
f_1 = tf.layers.conv1d(seq_fts, 1, 1)
f_2 = tf.layers.conv1d(seq_fts, 1, 1)
f_1 = tf.reshape(f_1, (nb_nodes, 1))
f_2 = tf.reshape(f_2, (nb_nodes, 1))
f_1 = adj_mat*f_1
f_2 = adj_mat * tf.transpose(f_2, [1,0])
logits = tf.sparse_add(f_1, f_2)
lrelu = tf.SparseTensor(indices=logits.indices,
values=tf.nn.leaky_relu(logits.values),
dense_shape=logits.dense_shape)
coefs = tf.sparse_softmax(lrelu)
if coef_drop != 0.0:
coefs = tf.SparseTensor(indices=coefs.indices,
values=tf.nn.dropout(coefs.values, 1.0 - coef_drop),
dense_shape=coefs.dense_shape)
if in_drop != 0.0:
seq_fts = tf.nn.dropout(seq_fts, 1.0 - in_drop)
# As tf.sparse_tensor_dense_matmul expects its arguments to have rank-2,
# here we make an assumption that our input is of batch size 1, and reshape appropriately.
# The method will fail in all other cases!
coefs = tf.sparse_reshape(coefs, [nb_nodes, nb_nodes])
seq_fts = tf.squeeze(seq_fts)
vals = tf.sparse_tensor_dense_matmul(coefs, seq_fts)
vals = tf.expand_dims(vals, axis=0)
vals.set_shape([1, nb_nodes, out_sz])
ret = tf.contrib.layers.bias_add(vals)
# residual connection
if residual:
if seq.shape[-1] != ret.shape[-1]:
ret = ret + conv1d(seq, ret.shape[-1], 1) # activation
else:
ret = ret + seq
return activation(ret) # activation
相应的位置都要使用稀疏的方式。
- Codeforces Beta Round #2 A,B,C
- 牛顿迭代法(Newton's Method)
- 最长递减子序列(nlogn)(个人模版)
- Selenium2+python自动化26-js处理内嵌div滚动条
- Selenium2+python自动化25-js处理日历控件
- 转负二进制(个人模版)
- Selenium2+python自动化24-js处理富文本
- 【干货】对抗自编码器PyTorch手把手实战系列——PyTorch实现对抗自编码器
- Selenium2+python自动化23-富文本(自动发帖)
- 2-Sat+输出可行解(个人模版)
- 协同过滤原理及Python实现
- 每周学点大数据 | No.25二叉搜索树回顾(二)
- RBF神经网络及Python实现(附源码)
- 【干货】计算机视觉实战系列03——用Python做图像处理
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 所谓并发编程,所谓有其三
- Redis 中的 3 种特殊数据类型
- 关于 “栈” 的那点破事
- 【为宏正名】for的妙用你想不到
- Centos 7 下 docker 导入导出镜像 实践笔记
- linux(centos)搭建.net core 运行环境
- 11g RAC 在线存储迁移实现 OCR 磁盘组完美替换
- MAC下 Centos7 下 免账号免密码便捷登录服务器的正确姿势 实践笔记
- AnimatedList 介绍及使用
- Flutter之SliverAppBar
- OpenGL ES 环境搭建
- Asp.Net Core 程序部署到Linux(centos)生产环境(一):普通部署
- Asp.Net Core 程序部署到Linux(centos)生产环境(二):docker部署
- docker-compose 安装jenkins的正确姿势 实践笔记
- windows安装nginx注册为服务的正确姿势 并设置开机自启 实践笔记