我本来打算用tensorflow 2.0去写capsule net的,结果被tf 2.0中的tensorboard坑的放弃了...

然后我换成tf 1.13去写了,下面做一个实现过程记录.

前言

其实我感觉capsule net应该是全连接层的升级版,以前的全连接层矩阵输入为x=[batch,m],现在capsule net的全连接层为x=[batch,m,len],就是将以前的一维特征扩展成了向量特征.

数据流动过程

其实我一开始的时候一直想不通数据流动时候的矩阵维度,所以卡壳了,下面就讲一下数据流动过程.

参考的网络结构是:

其实从reshape开始才是capsule net,1152为向量个数,8为向量长度.下一层连接到10个长为16的向量.最后计算向量的长度范数来匹配minsit的分类节点.

然后在capsDense中使用routing算法来更新c的值,最后输出.

踩过的坑

1. squash

问题描述

激活函数公式为:

vj=||sj||21+||sj||2sj||sj||

下面my_squash是我写的:

def my_squash(s):
s_norm = tf.norm_v2(s)
s_square_norm = tf.square(s_norm)
v = (s_square_norm * s)/((1+s_square_norm)*s_norm)
return v
然后我训练的时候一直没有效果,我找了半天才明白.

问题解决

问题在于tf.norm_v2这里的维度控制,他这里的范数指的是每一个向量的长度范数,所以需要指定维度为s_norm = tf.norm_v2(s, axis=-1, keepdims=True),下面是正确的维度演示: Sj=[batch,1152,8]||Sj||=[batch,1152,1]||Sj||2=[batch,1152,1]v=[batch,1152,8]

并且这个函数其实可以优化为: vj=||sj||21+||sj||2sj||sj||=||sj||sj1+||sj||2 代码为:

with tf.variable_scope('squash'):
s_norm = tf.norm_v2(s, axis=-1, keepdims=True)
s_square_norm = tf.square(s_norm)
v = (s_norm * s)/(1+s_square_norm)
return v