論文閱讀筆記: MixMatch: A Holistic Approach to Semi-Supervised Learning (2019)
論文連結: https://arxiv.org/pdf/1905.02249.pdf
很推薦看此篇知乎, 評論也一起看, 我也會在文章內引用這篇的內容: 超強半監督學習MixMatch
前言:
在這篇"The Quiet Semi-Supervised Revolution", 由Google Principal Scientist撰寫的SSL (Semi-Supervised Learning)的未來展望文章中提到:
SSL在未來表現可能如下圖, 不管labeled data的數量有多少, 表現將一直高於監督式方法, 有機會刷新一波深度學習各任務的記錄。
MixMatch是作者提到的2篇超強SSL論文中的一篇(另一篇是UDA)。
1. Introduction
最近大部分的SSL主要是在loss function後加上計算unlabeled data的loss term, 讓整體模型更generalize, 這些loss term主要分為3種(會在Related Work簡介):
- Consistency Regularization
- Entropy Minimization
- Generic Regularization(Traditional Regularization)
而MixMatch只有一個loss term, 但此loss term包含了以上3種方法。
2. Related Work
2-1. Consistency Regularization
可參考此篇知乎: 半監督深度學習又小結之Consistency Regularization
相關論文: Temporal Ensembling, Virtual Adversarial Training, Mean Teachers
Consistency Regularization 的主要思想是:對於一個輸入,即使受到微小干擾,其預測都應該是一致的。
這裡的干擾可以當作data augmentation, 比如在做圖片分類時, 對label image做翻轉、平移、旋轉等, 但label不變, 就能讓model更泛化。
但unlabeled data沒有label, 如果要透過augmented unlabeled data來增加model的泛化程度, 就需要Consistency Regularization技術。
所以Consistency Regularization目的為: 讓unlabeled data經過augment後, classifier也應輸出相同的class distribution(保持consistency)。
正式的定義如下:
- unlabeled example x 跟 Augment(x)的分類結果應該是一樣的
- Augment(x)為隨機的transformation, 因此下式的兩項並不相同
2-2. Entropy Minimization
相關論文: Pseudo-Label, Virtual Adversarial Training
這個跟SSL的2大假設之low-density separation有關(decision boundary應位於資料點較少的區域), 也就是讓unlabeled data具有low-entropy的prediction(low-entropy代表class distribution很peak, 高機率屬於某一個類別)。
MixMatch將entropy minimization用在sharpening(產生guessed labels時會使用, 第3節會介紹)
2-3. Traditional Regularization
為了讓model泛化能力增加而加入, MixMatch使用(Weight decay, L2) + MixUp(第3節會解釋)
MixUp是從訓練集中任意選2個樣本, 產生混合樣本與混合標籤
MixUp可以同時作為Regularizer(應用於labeled data) & SSL method(應用於unlabeled data)
3. MixMatch
MixMatch是一個融合各大SSL主流思想(第2節介紹的)的”holistic”方法。
MixMatch整體流程如下圖:
Eq.2: 對於batch X的labeled data & batch U的unlabeled data(U size=X size)進行MixMatch, 產生augmented labeled data X’ & augmented unlabeled data U’ (包含guessed labels, 在3-2節會解釋)
Eq.3, 4: 分別計算X’, U’的loss(loss function會在3-4節解釋)
Eq.5: 以SSL方式結合兩個loss, 同Eq.5
MixMatch algorithm在3-3節, 在此之前先介紹MixMatch中各個component:
3–1. Data Augmentation
在MixMatch中, 只會對labeled data xb做1次augmentation, 對unlabeled data ub做K次augmentation(K為超參數), K個augmentations將用來產生guessed label qb。
因為實驗是圖像分類, 所以MixMatch的Data Augmentation使用隨機水平翻轉與裁切(Crop)。
3–2. Label Guessing
對於batch U中的所有unlabeled data都會產生guessed label, 如下圖:
- 對unlabeled data ub進行K次augmentation
- 將K個augmentations輸入同個classifier, 產生K個predictions
- 平均此K個predictions, 同Eq.6
- Sharpen平均後的結果, 同Eq.7
Sharpening
在這裡使用了entropy minimization, 透過sharpening來讓ub的class distribution更peak, 也讓預測結果更consistency。
調整distribution的temperature來實現:
當T越趨近0時, Sharpen function輸出會越趨近one-hot distribution。
3–3. MixUp
為了實現SSL, MixMatch會對labeled data & unlabeled data(有guessed label)都進行MixUp, MixMatch對MixUp有些微修改。
MixMatch中的MixUp algorithm:
- 對於任兩個data pair (x1, p1), (x2, p2)會計算 (x’, p’), 同Eq.10, Eq.11;其中p為labels probabilities
- λ為Beta distribution for MixUp(代表examples的混合比例, 為0~1的數), a為Beta distribution parameter(超參數), 同Eq.8
- MixMatch的MixUp調整在Eq.9, 原本的MixUp為λ' = λ (為以下第2點)
接下來介紹MixMatch中是如何使用MixUp的:
- Concat augmented labeled data batch Xˆ & augmented unlabeled data batch Uˆ (含guessed label);Xˆ只有一個batch, Uˆ有K個
- Shuffle concat collections得到W
- 先對Xˆ中每個examples, 與W中對應順序的examples進行MixUp (由於MixMatch修改了MixUp, 使得Xˆ中的examples權重較高)
- 再對K個Uˆ中每個examples, 與W中對應順序的examples進行MixUp(由W中的第|Xˆ|個開始, 不使用已與Xˆ進行MixUp的examples)
- Xˆ的MixUp結果放入X’, Uˆ的MixUp結果放入U'
MixMatch Algorithm結論:
1. MixMatch將 X 轉換成 X'
X’為augmentation+MixUp的labeled examples collections (可能與unlabeled data MixUp過了)
2. MixMatch將 U 轉換成 U’
U’為multiple augmentation+MixUp+guessed label的unlabeled examples collections (可能與labeled data MixUp過了)
所以這裡的unlabeled data有K個augmentation, 且相同exmaples的每個augmentation的guessed label都相同。
3-4. Loss Function
以下介紹各式中的參數含意:
Eq.3: |X’|=batch_size, H(p, Pmodel)為cross-entropy, x為augmented labeled data, p為true-label, Pmodel為model對x的預測
Eq.4: |U’|=K*batch_size, L=class數, u為augmented unlabeled data, q為guessed label, Pmodel為model對u的預測
而Lu使用L2 loss, 而不是像Lx一樣使用cross-entropy是因為L2 loss比cross-entropy更嚴格(解釋引用此篇文章, 下面貼的原文寫的很簡略):
Cross-entropy計算時需要先softmax, 將output score轉成distribution;但softmax對於常數疊加不敏感, 若output為score+c, 則softmax結果不變, 所以2個augmented unlabeled data在output score可以相差一個常數。
但L2 loss直接作用在output score上, 要求2次輸入需相等, 更為嚴格。
We use this L2 loss in eq. (4) (the multiclass Brier score [5]) because, unlike the cross-entropy, it is bounded and less sensitive to incorrect predictions. For this reason, it is often used as the unlabeled data loss in SSL [25, 44] as well as a measure of predictive uncertainty [26].We do not propagate gradients through computing the guessed labels, as is standard [25, 44, 31, 35]
以上原文提到SSL的unlabeled data通常都用L2 loss, 且因為guessed labels q依賴於模型參數θ, 因此不將q對θ的梯度反向傳播(重要的細節)。
Eq.5: λu = 非監督式loss function的weight(可調), 此篇使用100作為開始
3–5. Hyperparameters
超參數總共有4個, 可以固定且也不用pre-tuned:
- Sharpening temparature: T
- Number of unlabeled augmentations: K
- Beta distribution parameter for MixUp: a
- Unsupervised loss weight: λu
此實驗設置T=0.5, K=2, 只有pre-tuned a & λu (a=0.75, λu=100為不錯的起點, λu會線性上升直到16000 steps到達最大值)
其實這裡有2個重點(取自此篇知乎的評論區):
- λu會越來越高 -> 代表unlabeled data loss權重越大
- unlabeled data augment很多次, 但labeled data只有1次
以上2點表示想防止model在labeled data上overfitting(也是實務上的重點)
根據下圖Beta(0.75, 0.75)的分布圖, 大部分抽樣出的λ落在0, 1附近, 就能使λ'接近1, 讓MixUp時, Xˆ或Uˆ的權重很大。
4. Experiments
4-1. Implementation details
實驗中均使用Wide ResNet-28模型, 模型和訓練過程與此篇論文非常像。
不同點在於:
- Decaying: 使用exponential moving average(EMA)更新參數, decay rate為0.999(控制model的更新速度, tensorflow有method), 而不是decay lr
- 當每次Wide ResNet-28更新時使用0.0004 weight decay
- 設置checkpoint在每2¹⁶個training examples, 並report median error rate of the last 20 checkpoints
4-2. SSL Benchmarks
評估MixMatch在4個standard benchmark datasets: CIFAR-10, CIFAR-100, SVHN, STL-10。
前3個資料集評估SSL的標準做法為: 將training set中大部分examples作為unlabeled data, 只留一小部分做為labeled data (unlabeled data的來源應該是把examples的label直接去掉)。
而STL-10是專門為SSL設計的dataset, 有5000 labeled images & 100000 unlabeled images (unlabeled data distribution與labeled data distribution略有不同)。
4-2-1. Baseline Methods
使用此篇論文應用到的4種方法: Π-Model, Mean Teacher, Virtual Adversarial Training, Pseudo-Label。
另外還單獨使用了MixUp, 但MixUp被設計為用於監督式學習的regularizer, 於是將它以SSL方式修改: 應用於augmented labeled examples & augmented unlabeled examples with their corresponding predictions。
並計算MixUp生成的混合label與model對MixUp生成的混合examples的prediction間的cross-entropy。
4-2-2. Results
CIFAR-10
使用250, 500, 1000, 2000, 4000的labeled data來評估各方法準確性, 並設置λu = 75。
可發現MixMatch的性能大大優於所有其他方法, 4000個labeled data的錯誤率只有6.24%;在250個labeled data上, 次佳方法VAT的錯誤率達到36.03%,比MixMatch高4.5倍(MixMatch的錯誤率為11.08%)。
而且在4000個labeled data, 次佳方法Mean Teacher獲得了10.36%的錯誤率, MixMatch只需1/16倍的labeled data(250個)就能達到類似的性能。
另外將監督式學習做為參考, 對所有50000個樣本進行監督式訓練可獲得4.17%的錯誤率, 當作錯誤率的極限。
CIFAR-10 and CIFAR-100 with a larger model
由於baseline只有1.5 million-parameter, 想探討MixMatch在更大的model上的表現, 將28-layer Wide Resnet改為每層有135個filters, 共有26 million-parameter, 也進行labeled-data有10000個的實驗。
對於CIFAR-10, 使用λu= 75;對於CIFAR-100, 使用λU= 150。
SVHN and SVHN+Extra
使用λu=250, 其他設置類似CIFAR-10實驗。
可發現MixMatch的性能相對穩定(並且比所有其他方法都要好), 雖然Mean Teacher獲得極佳的性能, 但錯誤率始終比MixMatch更高。
其實SVHN有2個training sets: train and extra, 所以在 fully-supervised learning中, 要將這兩個sets串接起來 (604388 training examples), 但在SSL只使用73257 examples的train set。
由於SVHN + Extra資料較多, 使用α= 0.25,λu= 250和較低的weight decay 0.000002;在MixMatch幾乎要達到fully-supervised performance (All的column), 且MixMatch在SVHN + Extra上所有的表現皆優於fully supervised training on SVHN without extra (2.59% error)。
重點來了, 如果在SVHN (73257 examples)中如果只有250個labeled examples, 則可以考慮以下2種選擇:
- 獲取8倍以上的unlabeled data並使用 MixMatch
- 獲取293倍以上的labeled data並使用 fully-supervised learning
這個實驗結果指出, 獲取更多unlabeled data並使用MixMatch更有效。
STL-10
雖然上表的baseline並沒有使用相同的實驗設置, 但MixMatch的錯誤率幾乎是1/2倍, 所以沒關係;設置λu = 50。
4-2-3 Ablation Study
因為MixMatch使用了很多SSL的機制, 因此透過在CIFAR-10上做消融測試, 了解MixMatch的性能, 主要衡量以下5點:
- 使用對K個augmentation取mean class distribution (我想應該就是averaging)或使用single augmentation的class distribution (K=1)
- 移除temperature sharpening, T=1
- 產生guessed labels使用EMA
- 只在labeled examples或unlabeled examples使用MixUp, 或都不用
- 使用Interpolation Consistency Training: 消融測試的一種, 只使用unlabeled mixup, 無sharpening, 且產生guessed labels有使用EMA
4-3. 其他實驗: 比較其他SSL方法(可想而知MixMatch效果最好)
5. Conclusion
此篇論文介紹了一種半監督學習方法MixMatch, 它結合當前SSL主流的思想和機制; 透過實驗發現所有設置中, MixMatch與其他方法相比, 其表現均得到顯著改善, 其錯誤率通常降低了兩倍以上或更多。
未來也可以繼續探索與混合其他有趣的SSL方法, 看哪些機制可以組合成有效的新方法。
另外, 現在大多數SSL都是在圖片benchmark上評估的, 可以探索MixMatch在其他領域的效果。
結語:
- 基本上就是基於擾動的方法, 只是擾動(DA, Averaging, Sharpening, MixUp)更複雜一些, idea本身不複雜, 主要還是圍繞在consistency regularization, 但performance真的高很多
- MixUp的影響真的很大, 可以看看NLP中的MixUp該怎麼做
- DA方法也許可採用複雜一點的, 此篇只使用簡單的隨機水平翻轉和裁切
- UDA好像又比MixMatch表現還優秀, 下一篇就會看這篇, 看要如何跟MixMatch結合
- 大家可能覺得我這篇廢話太多, 我寫筆記的習慣還是傾向把原文翻譯過來(因為這篇是邊看邊寫), 但還是有加自己的東西啦QQ
- 我的論文基本上就是在找有沒有除了DA以外, 在labeled data有限情況下還能增加模型preformance的方法或小技巧(主要是SSL), 剛好這篇用到的小技巧都是提升performance的關鍵(averaging, sharpening, MixUp)
最後梳理一下重點:
- 利用labeled data & unlabeled data augmentation來增加模型泛化能力
- unlabeled data利用K次augmentation averaging & sharpening 得到guessed label, 來達成consistency regularization
- 將augment後的labeled data與unlabeled data MixUp, 得到混合樣本
- 以上作法完美融合unlabeled data與labeled data的訓練