CapsNet实现以及踩坑
我本来打算用tensorflow 2.0
去写capsule net
的,结果被tf 2.0
中的tensorboard
坑的放弃了...
然后我换成tf 1.13
去写了,下面做一个实现过程记录.
前言
其实我感觉capsule net
应该是全连接层的升级版,以前的全连接层矩阵输入为x=[batch,m],现在capsule net
的全连接层为x=[batch,m,len],就是将以前的一维特征扩展成了向量特征.
数据流动过程
其实我一开始的时候一直想不通数据流动时候的矩阵维度,所以卡壳了,下面就讲一下数据流动过程.
其实从reshape
开始才是capsule net
,1152为向量个数,8为向量长度.下一层连接到10个长为16的向量.最后计算向量的长度范数来匹配minsit
的分类节点.
然后在capsDense
中使用routing
算法来更新c
的值,最后输出.
踩过的坑
1. squash
问题描述
激活函数公式为:
vj=||sj||21+||sj||2sj||sj||
下面my_squash
是我写的:
def my_squash(s): |
问题解决
问题在于tf.norm_v2
这里的维度控制,他这里的范数指的是每一个向量的长度范数,所以需要指定维度为s_norm = tf.norm_v2(s, axis=-1, keepdims=True)
,下面是正确的维度演示:
令Sj=[batch,1152,8]则||Sj||=[batch,1152,1]||Sj||2=[batch,1152,1]v=[batch,1152,8]
并且这个函数其实可以优化为: vj=||sj||21+||sj||2sj||sj||=||sj||sj1+||sj||2 代码为:
with tf.variable_scope('squash'): |