論文閱讀筆記: MixMatch: A Holistic Approach to Semi-Supervised Learning (2019)

甘樂
20 min readNov 8, 2020

--

論文連結: https://arxiv.org/pdf/1905.02249.pdf

很推薦看此篇知乎, 評論也一起看, 我也會在文章內引用這篇的內容: 超強半監督學習MixMatch

前言:

在這篇"The Quiet Semi-Supervised Revolution", 由Google Principal Scientist撰寫的SSL (Semi-Supervised Learning)的未來展望文章中提到:

SSL在未來表現可能如下圖, 不管labeled data的數量有多少, 表現將一直高於監督式方法, 有機會刷新一波深度學習各任務的記錄。

MixMatch是作者提到的2篇超強SSL論文中的一篇(另一篇是UDA)。

1. Introduction

最近大部分的SSL主要是在loss function後加上計算unlabeled data的loss term, 讓整體模型更generalize, 這些loss term主要分為3種(會在Related Work簡介):

  1. Consistency Regularization
  2. Entropy Minimization
  3. Generic Regularization(Traditional Regularization)

而MixMatch只有一個loss term, 但此loss term包含了以上3種方法。

2. Related Work

2-1. Consistency Regularization

可參考此篇知乎: 半監督深度學習又小結之Consistency Regularization

相關論文: Temporal Ensembling, Virtual Adversarial Training, Mean Teachers

Consistency Regularization 的主要思想是:對於一個輸入,即使受到微小干擾,其預測都應該是一致的。

這裡的干擾可以當作data augmentation, 比如在做圖片分類時, 對label image做翻轉、平移、旋轉等, 但label不變, 就能讓model更泛化。

但unlabeled data沒有label, 如果要透過augmented unlabeled data來增加model的泛化程度, 就需要Consistency Regularization技術。

所以Consistency Regularization目的為: 讓unlabeled data經過augment後, classifier也應輸出相同的class distribution(保持consistency)。

正式的定義如下:

  1. unlabeled example x 跟 Augment(x)的分類結果應該是一樣的
  2. Augment(x)為隨機的transformation, 因此下式的兩項並不相同
此為Consistency Regularization的loss term, x為unlabeled data, y為預測結果, θ為模型參數, 最右邊代表squared L2 loss

2-2. Entropy Minimization

相關論文: Pseudo-Label, Virtual Adversarial Training

這個跟SSL的2大假設之low-density separation有關(decision boundary應位於資料點較少的區域), 也就是讓unlabeled data具有low-entropy的prediction(low-entropy代表class distribution很peak, 高機率屬於某一個類別)。

MixMatch將entropy minimization用在sharpening(產生guessed labels時會使用, 第3節會介紹)

2-3. Traditional Regularization

為了讓model泛化能力增加而加入, MixMatch使用(Weight decay, L2) + MixUp(第3節會解釋)

MixUp是從訓練集中任意選2個樣本, 產生混合樣本與混合標籤

MixUp可以同時作為Regularizer(應用於labeled data) & SSL method(應用於unlabeled data)

3. MixMatch

MixMatch是一個融合各大SSL主流思想(第2節介紹的)的”holistic”方法。

MixMatch整體流程如下圖:

Eq.2: 對於batch X的labeled data & batch U的unlabeled data(U size=X size)進行MixMatch, 產生augmented labeled data X’ & augmented unlabeled data U’ (包含guessed labels, 在3-2節會解釋)

Eq.3, 4: 分別計算X’, U’的loss(loss function會在3-4節解釋)

Eq.5: 以SSL方式結合兩個loss, 同Eq.5

MixMatch algorithm在3-3節, 在此之前先介紹MixMatch中各個component:

3–1. Data Augmentation

在MixMatch中, 只會對labeled data xb做1次augmentation, 對unlabeled data ub做Kaugmentation(K為超參數), K個augmentations將用來產生guessed label qb

因為實驗是圖像分類, 所以MixMatch的Data Augmentation使用隨機水平翻轉與裁切(Crop)。

labeled data augmetaion
unlabeled data augmentation

3–2. Label Guessing

對於batch U中的所有unlabeled data都會產生guessed label, 如下圖:

  1. 對unlabeled data ub進行K次augmentation
  2. 將K個augmentations輸入同個classifier, 產生K個predictions
  3. 平均此K個predictions, 同Eq.6
  4. Sharpen平均後的結果, 同Eq.7

Sharpening

在這裡使用了entropy minimization, 透過sharpening來讓ub的class distribution更peak, 也讓預測結果更consistency。

調整distribution的temperature來實現:

p為input class distribution(也就是avg predictions), T為超參數

當T越趨近0時, Sharpen function輸出會越趨近one-hot distribution。

3–3. MixUp

為了實現SSL, MixMatch會對labeled data & unlabeled data(有guessed label)都進行MixUp, MixMatch對MixUp有些微修改。

MixMatch中的MixUp algorithm:

  1. 對於任兩個data pair (x1, p1), (x2, p2)會計算 (x’, p’), 同Eq.10, Eq.11;其中p為labels probabilities
  2. λ為Beta distribution for MixUp(代表examples的混合比例, 為0~1的數), a為Beta distribution parameter(超參數), 同Eq.8
  3. MixMatch的MixUp調整在Eq.9, 原本的MixUp為λ' = λ (為以下第2點)
MixUp algorithm in MixMatch

接下來介紹MixMatch中是如何使用MixUp的:

  1. Concat augmented labeled data batch Xˆ & augmented unlabeled data batch Uˆ (含guessed label);Xˆ只有一個batch, Uˆ有K個
  2. Shuffle concat collections得到W
  3. 先對Xˆ中每個examples, 與W中對應順序的examples進行MixUp (由於MixMatch修改了MixUp, 使得Xˆ中的examples權重較高)
  4. 再對K個Uˆ中每個examples, 與W中對應順序的examples進行MixUp(由W中的第|Xˆ|個開始, 不使用已與Xˆ進行MixUp的examples)
  5. Xˆ的MixUp結果放入X’, Uˆ的MixUp結果放入U'
在MixMatch Algorithm中使用MixUp
MixMatch Algorithm

MixMatch Algorithm結論:

1. MixMatch將 X 轉換成 X'

X’為augmentation+MixUp的labeled examples collections (可能與unlabeled data MixUp過了)

2. MixMatch將 U 轉換成 U’

U’為multiple augmentation+MixUp+guessed label的unlabeled examples collections (可能與labeled data MixUp過了)

所以這裡的unlabeled data有K個augmentation, 且相同exmaples的每個augmentation的guessed label都相同。

3-4. Loss Function

以下介紹各式中的參數含意:

Eq.3: |X’|=batch_size, H(p, Pmodel)為cross-entropy, x為augmented labeled data, p為true-label, Pmodel為model對x的預測

Eq.4: |U’|=K*batch_size, L=class數, u為augmented unlabeled data, q為guessed label, Pmodel為model對u的預測

而Lu使用L2 loss, 而不是像Lx一樣使用cross-entropy是因為L2 loss比cross-entropy更嚴格(解釋引用此篇文章, 下面貼的原文寫的很簡略):

Cross-entropy計算時需要先softmax, 將output score轉成distribution;但softmax對於常數疊加不敏感, 若output為score+c, 則softmax結果不變, 所以2個augmented unlabeled data在output score可以相差一個常數。

但L2 loss直接作用在output score上, 要求2次輸入需相等, 更為嚴格。

We use this L2 loss in eq. (4) (the multiclass Brier score [5]) because, unlike the cross-entropy, it is bounded and less sensitive to incorrect predictions. For this reason, it is often used as the unlabeled data loss in SSL [25, 44] as well as a measure of predictive uncertainty [26].We do not propagate gradients through computing the guessed labels, as is standard [25, 44, 31, 35]

以上原文提到SSL的unlabeled data通常都用L2 loss, 且因為guessed labels q依賴於模型參數θ, 因此不將q對θ的梯度反向傳播(重要的細節)。

Eq.5: λu = 非監督式loss function的weight(可調), 此篇使用100作為開始

3–5. Hyperparameters

超參數總共有4個, 可以固定且也不用pre-tuned:

  1. Sharpening temparature: T
  2. Number of unlabeled augmentations: K
  3. Beta distribution parameter for MixUp: a
  4. Unsupervised loss weight: λu

此實驗設置T=0.5, K=2, 只有pre-tuned a & λu (a=0.75, λu=100為不錯的起點, λu會線性上升直到16000 steps到達最大值)

其實這裡有2個重點(取自此篇知乎的評論區):

  1. λu會越來越高 -> 代表unlabeled data loss權重越大
  2. unlabeled data augment很多次, 但labeled data只有1次

以上2點表示想防止model在labeled data上overfitting(也是實務上的重點)

根據下圖Beta(0.75, 0.75)的分布圖, 大部分抽樣出的λ落在0, 1附近, 就能使λ'接近1, 讓MixUp時, Xˆ或Uˆ的權重很大。

Beta(0.75, 0.75)的分布, 來源: 超強半監督學習MixMatch

4. Experiments

4-1. Implementation details

實驗中均使用Wide ResNet-28模型, 模型和訓練過程與此篇論文非常像。

不同點在於:

  1. Decaying: 使用exponential moving average(EMA)更新參數, decay rate為0.999(控制model的更新速度, tensorflow有method), 而不是decay lr
  2. 當每次Wide ResNet-28更新時使用0.0004 weight decay
  3. 設置checkpoint在每2¹⁶個training examples, 並report median error rate of the last 20 checkpoints

4-2. SSL Benchmarks

評估MixMatch在4個standard benchmark datasets: CIFAR-10, CIFAR-100, SVHN, STL-10。

前3個資料集評估SSL的標準做法為: 將training set中大部分examples作為unlabeled data, 只留一小部分做為labeled data (unlabeled data的來源應該是把examples的label直接去掉)。

而STL-10是專門為SSL設計的dataset, 有5000 labeled images & 100000 unlabeled images (unlabeled data distribution與labeled data distribution略有不同)。

4-2-1. Baseline Methods

使用此篇論文應用到的4種方法: Π-Model, Mean Teacher, Virtual Adversarial Training, Pseudo-Label

另外還單獨使用了MixUp, 但MixUp被設計為用於監督式學習的regularizer, 於是將它以SSL方式修改: 應用於augmented labeled examples & augmented unlabeled examples with their corresponding predictions。

並計算MixUp生成的混合label與model對MixUp生成的混合examples的prediction間的cross-entropy。

4-2-2. Results

CIFAR-10

使用250, 500, 1000, 2000, 4000的labeled data來評估各方法準確性, 並設置λu = 75。

可發現MixMatch的性能大大優於所有其他方法, 4000個labeled data的錯誤率只有6.24%;在250個labeled data上, 次佳方法VAT的錯誤率達到36.03%,比MixMatch高4.5倍(MixMatch的錯誤率為11.08%)。

而且在4000個labeled data, 次佳方法Mean Teacher獲得了10.36%的錯誤率, MixMatch只需1/16倍的labeled data(250個)就能達到類似的性能。

另外將監督式學習做為參考, 對所有50000個樣本進行監督式訓練可獲得4.17%的錯誤率, 當作錯誤率的極限。

CIFAR-10 and CIFAR-100 with a larger model

由於baseline只有1.5 million-parameter, 想探討MixMatch在更大的model上的表現, 將28-layer Wide Resnet改為每層有135個filters, 共有26 million-parameter, 也進行labeled-data有10000個的實驗。

對於CIFAR-10, 使用λu= 75;對於CIFAR-100, 使用λU= 150。

SVHN and SVHN+Extra

使用λu=250, 其他設置類似CIFAR-10實驗。

可發現MixMatch的性能相對穩定(並且比所有其他方法都要好), 雖然Mean Teacher獲得極佳的性能, 但錯誤率始終比MixMatch更高。

其實SVHN有2個training sets: train and extra, 所以在 fully-supervised learning中, 要將這兩個sets串接起來 (604388 training examples), 但在SSL只使用73257 examples的train set。

由於SVHN + Extra資料較多, 使用α= 0.25,λu= 250和較低的weight decay 0.000002;在MixMatch幾乎要達到fully-supervised performance (All的column), 且MixMatch在SVHN + Extra上所有的表現皆優於fully supervised training on SVHN without extra (2.59% error)。

重點來了, 如果在SVHN (73257 examples)中如果只有250個labeled examples, 則可以考慮以下2種選擇:

  1. 獲取8倍以上的unlabeled data並使用 MixMatch
  2. 獲取293倍以上的labeled data並使用 fully-supervised learning

這個實驗結果指出, 獲取更多unlabeled data並使用MixMatch更有效。

STL-10

雖然上表的baseline並沒有使用相同的實驗設置, 但MixMatch的錯誤率幾乎是1/2倍, 所以沒關係;設置λu = 50。

4-2-3 Ablation Study

因為MixMatch使用了很多SSL的機制, 因此透過在CIFAR-10上做消融測試, 了解MixMatch的性能, 主要衡量以下5點:

  1. 使用對K個augmentation取mean class distribution (我想應該就是averaging)或使用single augmentation的class distribution (K=1)
  2. 移除temperature sharpening, T=1
  3. 產生guessed labels使用EMA
  4. 只在labeled examples或unlabeled examples使用MixUp, 或都不用
  5. 使用Interpolation Consistency Training: 消融測試的一種, 只使用unlabeled mixup, 無sharpening, 且產生guessed labels有使用EMA
可以看出貢獻程度為: MixUp > sharpening > Averaging

4-3. 其他實驗: 比較其他SSL方法(可想而知MixMatch效果最好)

5. Conclusion

此篇論文介紹了一種半監督學習方法MixMatch, 它結合當前SSL主流的思想和機制; 透過實驗發現所有設置中, MixMatch與其他方法相比, 其表現均得到顯著改善, 其錯誤率通常降低了兩倍以上或更多。

未來也可以繼續探索與混合其他有趣的SSL方法, 看哪些機制可以組合成有效的新方法。

另外, 現在大多數SSL都是在圖片benchmark上評估的, 可以探索MixMatch在其他領域的效果。

結語:

  1. 基本上就是基於擾動的方法, 只是擾動(DA, Averaging, Sharpening, MixUp)更複雜一些, idea本身不複雜, 主要還是圍繞在consistency regularization, 但performance真的高很多
  2. MixUp的影響真的很大, 可以看看NLP中的MixUp該怎麼做
  3. DA方法也許可採用複雜一點的, 此篇只使用簡單的隨機水平翻轉和裁切
  4. UDA好像又比MixMatch表現還優秀, 下一篇就會看這篇, 看要如何跟MixMatch結合
  5. 大家可能覺得我這篇廢話太多, 我寫筆記的習慣還是傾向把原文翻譯過來(因為這篇是邊看邊寫), 但還是有加自己的東西啦QQ
  6. 我的論文基本上就是在找有沒有除了DA以外, 在labeled data有限情況下還能增加模型preformance的方法或小技巧(主要是SSL), 剛好這篇用到的小技巧都是提升performance的關鍵(averaging, sharpening, MixUp)

最後梳理一下重點:

  1. 利用labeled data & unlabeled data augmentation來增加模型泛化能力
  2. unlabeled data利用K次augmentation averaging & sharpening 得到guessed label, 來達成consistency regularization
  3. 將augment後的labeled data與unlabeled data MixUp, 得到混合樣本
  4. 以上作法完美融合unlabeled data與labeled data的訓練

--

--