Numerically Stable Computation of Cross Entropy
While reading CenterNet's heatmap loss function today, I noticed that its loss is essentially a cross entropy, but my TensorFlow implementation of it produced NaN losses. So I looked into how the Cross Entropy computation is optimized for numerical stability, and I am writing it down here.
The cross_entropy computation in TensorFlow
Let \(x = \text{logits}\) and \(z = \text{labels}\): \[ \begin{aligned} & z * -\log(\text{sigmoid}(x)) + (1 - z) * -\log(1 - \text{sigmoid}(x)) \\ =& z * -\log(\frac{1}{1 + e^{-x}}) + (1 - z) * -\log(\frac{e^{-x}}{1 + e^{-x}}) \\ =& z * \log(1 + e^{-x}) + (1 - z) * (-\log(e^{-x}) + \log(1 + e^{-x})) \\ =& z * \log(1 + e^{-x}) + (1 - z) * (x + \log(1 + e^{-x})) \\ =& (1 - z) * x + \log(1 + e^{-x}) \\ =& x - x * z + \log(1 + e^{-x}) \\ =& \log(e^x) - x * z + \log(1 + e^{-x}) \\ =& - x * z + \log(1 + e^{x}) \end{aligned} \]
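To see concretely why the naive form produces NaN while the rearranged one does not, here is a minimal sketch (assuming TF 2.x eager execution; the exact inf/NaN pattern depends on dtype):

import tensorflow as tf

x = tf.constant([-1000., -1000., 0.])  # logits
z = tf.constant([1., 0., 1.])          # labels

# Naive form: sigmoid(-1000.) underflows to 0 in float32, so log(p) = -inf
# and the z = 0 row computes 0 * inf = NaN.
p = tf.sigmoid(x)
naive = z * -tf.math.log(p) + (1 - z) * -tf.math.log(1 - p)
print(naive.numpy())  # roughly [inf, nan, 0.693]

# Rearranged form from the last line above: finite for these inputs,
# but tf.exp(x) still overflows once x is large and positive.
rearranged = -x * z + tf.math.log(1 + tf.exp(x))
print(rearranged.numpy())  # roughly [1000., 0., 0.693]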
To avoid numerical overflow of \(e^{x}\) for large positive \(x\), this is further rewritten as:
\[ \begin{aligned} & \log(1 + e^{x}) \\ =& \log(1 + e^{-|x|}) + \max(x, 0) \end{aligned} \]
NOTE: TensorFlow provides a dedicated function for \(\text{softplus}(x)=\log(1 + e^{x})\), namely tf.nn.softplus, which already includes this overflow fix.
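Combining the two pieces gives the overflow-safe form \(\max(x, 0) - x * z + \log(1 + e^{-|x|})\), which is what tf.nn.sigmoid_cross_entropy_with_logits computes internally. A minimal sketch (again assuming TF 2.x eager execution) checking that the formulations agree:

import tensorflow as tf

x = tf.constant([-1000., 0., 1000.])  # logits, including extreme values
z = tf.constant([0., 1., 1.])         # labels

# max(x, 0) - x * z + log(1 + exp(-|x|)): every term is overflow-safe
stable = tf.maximum(x, 0.) - x * z + tf.math.log1p(tf.exp(-tf.abs(x)))

# softplus applies the same |x| trick internally
via_softplus = -x * z + tf.nn.softplus(x)

builtin = tf.nn.sigmoid_cross_entropy_with_logits(labels=z, logits=x)
print(stable.numpy(), via_softplus.numpy(), builtin.numpy())  # all finite, all equal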
The FocalLoss computation in CenterNet
First, here is its FocalLoss code (_neg_loss, PyTorch, from the CenterNet repo):

import torch

def _neg_loss(pred, gt):
    ''' Modified focal loss. Exactly the same as CornerNet.
        Runs faster and costs a little bit more memory
    '''
    pos_inds = gt.eq(1).float()
    neg_inds = gt.lt(1).float()
    neg_weights = torch.pow(1 - gt, 4)

    pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
    neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds

    num_pos = pos_inds.sum()
    pos_loss = pos_loss.sum()
    neg_loss = neg_loss.sum()

    if num_pos == 0:
        return -neg_loss
    return -(pos_loss + neg_loss) / num_pos
NOTE: the pred here has already been passed through sigmoid.
Converting the above code to formulas, let \(x = \text{logits}\), \(z = \text{labels}\), \(x_s=\text{sigmoid}(x)\): \[ \begin{aligned} & -\log(\text{sigmoid}(x))*(1-x_s)^2-\log(1-\text{sigmoid}(x))* x_s^2\\ = & -\log(\frac{1}{1+e^{-x}})*(1-x_s)^2-\log(\frac{e^{-x}}{1+e^{-x}})* x_s^2\\ = & \log(1+e^{-x})*(1-x_s)^2+[-\log(e^{-x}) + \log(1 + e^{-x})]*x_s^2 \\ = & \text{softplus}(-x)*(1-x_s)^2+[x + \text{softplus}(-x)]*x_s^2 \end{aligned} \]
The corresponding optimized code (note that here pred_hm holds the raw logits, and the sigmoid is folded into the loss):

def focal_loss(self, true_hm: tf.Tensor, pred_hm: tf.Tensor) -> tf.Tensor:
    """ Modified focal loss. Exactly the same as CornerNet.
    Runs faster and costs a little bit more memory

    Parameters
    ----------
    true_hm : tf.Tensor
        shape : [batch, out_h, out_w, class_num]
    pred_hm : tf.Tensor
        logits, before sigmoid
        shape : [batch, out_h, out_w, class_num]

    Returns
    -------
    tf.Tensor
        heatmap loss
        shape : [batch, ]
    """
    z = true_hm
    x = pred_hm
    x_s = tf.sigmoid(pred_hm)
    pos_inds = tf.cast(tf.equal(z, 1.), tf.float32)
    neg_inds = 1 - pos_inds
    neg_weights = tf.pow(1 - z, 4)
    # -log(sigmoid(x)) * (1 - sigmoid(x))^2  at positive locations and
    # -log(1 - sigmoid(x)) * sigmoid(x)^2    at negative locations,
    # rewritten with softplus to stay finite for extreme logits
    loss = tf.add(tf.nn.softplus(-x) * tf.pow(1 - x_s, 2) * pos_inds,
                  (x + tf.nn.softplus(-x)) * tf.pow(x_s, 2) * neg_weights * neg_inds)
    num_pos = tf.reduce_sum(pos_inds, [1, 2, 3])
    loss = tf.reduce_sum(loss, [1, 2, 3])
    # returns 0 instead of NaN for images without any positive sample
    return tf.div_no_nan(loss, num_pos)
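As a sanity check on the derivation (a standalone sketch assuming TF 2.x eager execution, separate from the training code above), the softplus terms can be compared element-wise against the naive probability-space terms:

import numpy as np
import tensorflow as tf

x = tf.constant(np.linspace(-5., 5., 11), tf.float32)  # moderate logits
x_s = tf.sigmoid(x)

naive_pos = -tf.math.log(x_s) * tf.pow(1 - x_s, 2)
stable_pos = tf.nn.softplus(-x) * tf.pow(1 - x_s, 2)
naive_neg = -tf.math.log(1 - x_s) * tf.pow(x_s, 2)
stable_neg = (x + tf.nn.softplus(-x)) * tf.pow(x_s, 2)

# identical on moderate logits, up to float32 rounding
print(np.allclose(naive_pos.numpy(), stable_pos.numpy(), atol=1e-6),
      np.allclose(naive_neg.numpy(), stable_neg.numpy(), atol=1e-6))  # True True

# but on extreme logits the naive positive term hits log(0) = -inf,
# while the softplus version stays finite, which is what fixes the NaN loss
x = tf.constant([-1000., 1000.])
x_s = tf.sigmoid(x)
print((-tf.math.log(x_s) * tf.pow(1 - x_s, 2)).numpy())   # roughly [inf, 0.]
print((tf.nn.softplus(-x) * tf.pow(1 - x_s, 2)).numpy())  # roughly [1000., 0.]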