半监督学习:MixMatch
第七个算法MixMatch: A Holistic Approach toSemi-Supervised Learning
。此算法将之前的各个半监督学习算法进行融合,统一了主流方法,得到了最优的效果。此算法好,就是训练的过程慢一些。
算法理论
半监督学习主要以未标记数据减轻对标记数据的要求,许多半监督学习方法都是添加根据无标签数据所产生的损失,从而使得模型更好的将未知标签数据分类。在最近的算法中,所添加的大部分损失属于以下三类之一,首先设置一个模型\(\text{P}_{model}(y|x\theta)\),可以从通过参数\(\theta\)从输入\(x\)中得到类别\(y\)。
熵最小化
代表,Pseudo-label: The simple and efficient semi-supervised learning method fordeep neural networks)。做法是鼓励模型输出无标签数据的标签。
在许多半监督学习方法中,一个常见的基本假设是分类器的决策边界不应通过边缘数据分布的高密度区域。强制执行此操作的一种方法是要求分类器对未标记的数据输出低熵预测。如论文(Semi-supervised learning by entropy minimization)使用这种这种损失函数使得\(\text{P}_{model}(y|x\theta)\)关于无标签数据\(x\)的熵最小,这种方法和
VAT
结合可以得到更强的效果。pseudo label
通过对无标签数据的高置信度预测构造出伪标签进行训练,从而隐式的最小化熵。MixMatch
通过在无标签数据的预测分布上使用sharpening
函数,也隐式的最小化熵。一致性正则
鼓励模型在其输入受到扰动时产生相同的输出分布。最简单的例子如下,对于无标签数据\(x\): \[ \begin{align} \| \text { Pmodel }(y | \text { Augment }(x) ; \theta)-\text { Pmodel }(y | \text { Augment }(x) ; \theta) \|_{2}^{2} \end{align}\tag{1} \] 注意\(\text { Augment }\)是随机变化,所以公式1中的两项是不一样的。
mean teacher
通过代替公式1中的一项,使用模型参数的滑动平均进行模型输出,这提供了一个更稳定的输出分布,并发现了过去经验可以改善当前结果。这些方法的缺点是它们使用特定于域的数据增强策略。VAT
则通过计算扰动来解决这个问题,将这个扰动添加到输入中,从而最大程度的改变输出类别的分布。MixMatch
则通过对图像进行标准的数据增强来添加一致性正则。一般正则化
传统正则算法鼓励模型得到更好的拟合与泛化效果。在
MixMatch
中对模型参数用l2
正则,同时数据增强使用mixup
。
MixMatch
就是一种集合上以上三种方法的新方法。总之,MixMatch
引入了针对无标签数据的统一的损失项,既可以减少熵又可以保持一致正则性,还与一般正则化方法兼容。
MixMatch
首先给出一系列符号。给定一个batch
的标记数据\(\mathcal{X}\)以及one-hot
标签,一个batch
的无标签数据\(\mathcal{U}\),通过数据增强得到\(\mathcal{X}',\mathcal{U}'\),然后对他们分别计算损失,最终整合所有损失:
\[ \begin{align} \mathcal{X}^{\prime}, \mathcal{U}^{\prime}=\operatorname{MixMatch}(\mathcal{X}, \mathcal{U}, T, K, \alpha) \end{align}\tag{2} \]
\[ \begin{align} \mathcal{L}_{\mathcal{X}}=\frac{1}{\left|\mathcal{X}^{\prime}\right|} \sum_{x, p \in \mathcal{X}^{\prime}} \mathrm{H}\left(p, \mathrm{p}_{\text {model }}(y | x ; \theta)\right) \end{align}\tag{3} \]
\[ \begin{align} \mathcal{L}_{\mathcal{U}}=\frac{1}{L\left|\mathcal{U}^{\prime}\right|} \sum_{u, q \in \mathcal{U}^{\prime}}\left\|q-\mathrm{p}_{\text {model }}(y | u ; \theta)\right\|_{2}^{2} \end{align}\tag{4} \]
\[ \begin{align} \mathcal{L}=\mathcal{L}_{\mathcal{X}}+\lambda_{\mathcal{U}} \mathcal{L}_{\mathcal{U}} \end{align}\tag{5} \]
其中\(H(p,q)\)是分布\(p\)和\(q\)间的交叉熵,\(T,K,\alpha,\lambda_{\mathcal{U}}\)是超参数。完整的MixMatch
如算法1所示。
现在来描述各个部分:
数据增强
对于一个
batch
\(\mathcal{X}\)中的每一个\(x_b\),通过变化得到\(\hat{x}_{b}=\text { Augment }\left(x_{b}\right)\)(算法1第3行)。对于每个无标签数据\(u_b\),我们生成\(K\)个增强\(\hat{u}_{b, k}=\text { Augment }\left(u_{b}\right), k \in(1, \ldots, K)\)(算法1第5行)。再使用每个\(u_b\)送入模型得到对应的猜测标签
\(q_b\)。标签猜测
有了
猜测标签
,我们将它用在无监督损失中,平均对\(u_b\)做\(K\)个增强的模型预测输出分布:\[ \begin{align} \bar{q}_{b}=\frac{1}{K} \sum_{k=1}^{K} \operatorname{Prodel}\left(y | \hat{u}_{b, k} ; \theta\right) \end{align}\tag{6} \]
sharpening: 为了达到对熵最小化的目的,我们需要对给定数据增强预测的平均值\(\bar{q}_{b}\)进行
sharpening
,通过sharpening
函数减小标签的分布熵。在代码中,是调整分类分布的温度
系数:\[ \begin{align} \text { Sharpen }(p, T)_{i}:=p_{i}^{\frac{1}{T}} / \sum_{j=1}^{L} p_{j}^{\frac{1}{T}} \end{align}\tag{7} \]
其中\(p\)是一些输入分类分布(在此算法中为\(\bar{q}_{b}\)),\(T\)是超参数。当\(T\rightarrow0\),\(\text{Sharpen}(p,T)\)的输出会趋近于
Dirac(one-hot)分布
,降低温度系数会鼓励模型产生较低熵的预测。mixup
要应用
mixup
,我们首先需要将所有的带标签的增强数据和所有无标签样本以及对应的猜测标签收集起来(算法1第10-11行): \[ \begin{align} \hat{\mathcal{X}}=\left(\left(\hat{x}_{b}, p_{b}\right) ; b \in(1, \ldots, B)\right) \end{align}\tag{12} \] \[ \begin{align} \hat{\mathcal{U}}=\left(\left(\hat{u}_{b, k}, q_{b}\right) ; b \in(1, \ldots, B), k \in(1, \ldots, K)\right) \end{align}\tag{13} \]然后我们联合以上分布并进行混洗得到新的数据集\(\mathcal{W}\)作为
mixup
的输入,对每第\(i\)个样本对\(\hat{\mathcal{X}}\),我们计算\(\operatorname{MixUp}\left(\hat{\mathcal{X}}_{i}, \mathcal{W}_{i}\right)\)并将结果添加到\(\mathcal{X}'\)(算法1第13行),对于\(i\in(1,\ldots,|\bar{\mathcal{U}}|)\)我们计算\(\mathcal{U}_{i}^{\prime}=\operatorname{MixUp}\left(\hat{\mathcal{U}}_{i}, \mathcal{W}_{i+|\hat{\mathcal{X}}|}\right)\) for \(i \in(1, \ldots,|\hat{\mathcal{U}}|)\)。在这个过程中,带标签数据可能会和无标签数据产生混合。损失函数
损失即标签数据的交叉熵结合无标签数据的差异性损失。
超参数
因为
MixMatch
结合了很多算法,所以超参数也特别的多,一般固定\(T=0.5,K=2\),然后\(\alpha=0.75,\lambda_{\mathcal{U}}=100\)
消融测试结果:
可以发现关键提升点在于锐化
以及无标签数据间的mixup
代码
|
测试结果
使用默认参数以及cifar10中250张标注样本训练128个epoch,得到测试集准确率如下:
"last01": 74.08999633789062, |
的确是超越之前算法太多了,就是训练时期的速度相对慢三倍。