論文閱讀筆記: Unsupervised Data Augmentation for Consistency Training (2019)

甘樂
16 min readNov 15, 2020

--

論文連結: https://arxiv.org/pdf/1904.12848.pdf

前言:

UDA跟上篇介紹的MixMatch都是Google提出的非常優秀的DSSL (半監督式深度學習)方法, 而且兩篇發表時間非常接近且方法也很像(會另寫一篇文章詳述), 但UDA的效果比MixMatch更好, 而且UDA還有在NLP任務上實驗。

而且就算在labeled examples很大量的情況下, 使用UDA也會比純監督式學習還要好另外UDA也可與transfer learning結合很好

1. Introduction

現今有很多各式各樣的DSSL, 但基本上還是consistency training的概念(e.g. MixMatch, Mean-Teacher, VAT, Π-Model):

透過在input/hidden states增加noise, 但保持最終predictions不變, 作為model的regularizer, 因此也稱consistency regularization。

好的model應該對input/hidden states任何的noise injection都具robust, 傳統的noise有Gaussian noise/Dropout noise/Adversarial noise等。

其實在監督式學習上的SOTA DA方法, 同樣也會在consistency training上表現很好 (UDA原文有證明, 但這裡不討論), 所以UDA將noise injection替換成SOTA DA方法。

1-1. UDA的貢獻與發現:

  1. 在監督式學習上表現好的DA方法, 可作為consistency training中優質的noise injection。
  2. 同樣在大量labeled examples的情況下, UDA表現也比監督式學習還要好。
  3. UDA與transfer learning能結合的很不錯。
  4. 證明在監督式學習上的SOTA DA方法與UDA提升分類表現的關聯(此文不討論)。

2. Unsupervised Data Augmentation (UDA)

參數說明:
x: input
y*: true-label
Pθ(y|x): model/output distribition
θ: model參數
PL(x)/PU(x): labeled/unlabeled examples distribution
f*: prefect classifier
: augmented examples
q(x̂|x): augmentation transformation/noise operation

2.1 Background

DA會讓examples進行一些轉換, 但不改變其label, 因此valid q(x̂|x)也可被表示為x̂~q(x̂|x);由於x label一樣,可以透過negative log-likelihood來優化

若要提升model表現, augmented examples需要提供額外的inductive biases (歸納偏置, 以下為我的理解, 可參考此篇知乎):

歸納偏置代表model要預測一個example需做的推論(就像神經網路的weights, weights為學習訓練資料產生的規則, 而推論需要規則, augmented examples就需要提供更多的建立這些規則的資訊, 且這些資訊是原本訓練資料沒有的, 能讓model表現更好)。

但在監督式學習中的DA方法只能穩定提升且提升很有限, 因為DA在監督式學習只用在labeled examples上, 而這些labeled examples通常很少

但unlabeled examples通常很多, 因此UDA打算將這些SOTA DA方法應用在consistency training上。

2.2 UDA

使用consistency training的SSL, 可以簡單總結為以下2步驟, UDA也是:

  1. 對於x, 計算其Pθ(y|x) & Pθ(y|x, noise), 後者為在x或hidden states加入noise
  2. 最小化Pθ(y|x) & Pθ(y|x, noise)間的divergence
UDA架構圖

UDA主要關注noise operation影響consistency training performance這部分, 因為以往方法都只是加入簡單的noise在unlabeled examples, UDA假設在監督式的SOTA DA方法也能在consistency training中表現很好。

所以UDA用各任務上SOTA DA方法來在consistency training中作為noise injection, loss funtion如下:

loss function of UDA(+號左邊為supervised loss, 右邊為consistency loss, λ為平衡權重, 全部實驗設1), CE代表cross entropy, θ˜相等於θ但參數固定不會反向傳播
  1. supervised loss在labeled examples x1上計算negative log likelihood (也就是cross-entropy)
  2. consistency loss在計算unlabeled examples x2與其augmentation 的KL divergence
  3. 為何固定參數? 因為Pθ˜(y|x2)使用的是上個iteration更新後的參數(idea取自VAT)

2.3 Advanced data augmentation

Advanced data augmentation須符合以下3個條件(相較簡單的DA方法, 如圖片裁切、旋轉、翻轉等能提升更多model表現):

  1. Valid noise: 可保持consistency
  2. Diverse noise: 可大量修改原unlabeled examples但又保持consistency
  3. Targeted inductive biases: 可提供額外的inductive biases

2.4 Augmentation Strategies for Different Tasks

這裡只講原文中NLP的SOTA DA部分:

反向翻譯

1. Back-translation:
將原句翻譯成另種語言後再翻譯回來(使用WMT'14 English-French translation models), 能保留原句語義但敘述不同

生成不同的敘述對文本分類很重要, 因此UDA在翻譯時採用random sampling + tunable temperature取代beam search。(在3-2節介紹)

Problem: 有些字在翻譯前後都一直存在, 可以考慮使用Word-replacing

2. Word replacing with TF-IDF:
因為在文本分類中keywords很重要, 藉由替換掉low TF-IDF score的word來生成新樣本, 避免替換掉informative words(如keywords)讓label改變。(原文有詳細介紹但我看不太懂)

2.5 Additional Training Techniques

1.Confidence-based masking:

Consistency loss只計算最高的預測機率比threshold B還高的examples

2.Sharpening Predictions:

同MixMatch中的技巧, 讓unlabeled examples的預測分布更peak, 並結合confidence-based masking後的loss function (加在consistency loss上):

在一個mini-batch B下的loss, zy: logit of label y for example x, τ為超參數

3. Domain-relevance Data Filtering:

因為out-domain的unlabeled data比較好取得, 但out-domain的data distribution跟in-domain完全不一樣, 因此會先用in-domain data訓練baseline model來預測out-domain dataset, 只取出高信心值的out-domain data

UDA實際做法是先對每個類別以預測機率排列所有out-domain data, 並選取預測機率最高的out-domain data。

3. Extended Method Details

3.1 Training Signal Annealing for Low-data Regime (TSA)

在SSL訓練中, 其實model很容易快速地在labeled examples上overfitting, 同時在unlabeled data上underfitting。

所以UDA的想法是隨著training steps提升, 緩慢讓labeled data傳播training signal (supervised signal)。

具體作法:

  1. 設定一個threshold ηt, t為目前training step, 大小為1/K ~ 1(K為類別數量)
  2. 只傳回預測機率小於ηt的labeled examples x的training signals
  3. 如果labeled examples的正確類別預測機率, 即Pθ(y*|x), 大於threshold, 則從loss function中移除該labeled examples, 不讓model在簡單的labeled examples上過度訓練。
  4. ηt的變化週期有3種: log, linear, exp根據不同情況使用
    (1) 如果labeled data很少或problem很簡單, 可用exp週期, 讓訓練後期才大量放出supervised signal
    (2) 如果labeled data很多或使用效果好的regularization, 可用log週期
ηt的變化週期, y軸為αt, x軸為t/T, T為total training steps

3.2 Discussion on Trade-off Between Diversity and Validity for Data Augmentation

validity和diversity之間有trade-off, 因為diverse會大量修改原句, 很大機率會不符合原本的label, 所以在反向翻譯時使用tuning temperature of random sampling, 概念如下:

  1. 當temperature=0, 會生成很valid但與原句相同的敘述
  2. 當temperature=1, 會生成很diverse但完全不可讀的敘述

實驗發現將temperature設置成0.7, 0.8, 0.9表現最好。

4. Experiments

這裡只講NLP上的實驗, 實驗使用6個文本分類資料集

  1. 情感分類: IMDb, Yelp-2, Yelp-5, Amazon-2, Amazon-5
  2. 主題分類: DBPedia

4.1 Correlation between Supervised and Semi-supervised Performance

這實驗是驗證監督式的SOTA DA也能在SSL上表現不錯, 使用Yelp-5資料集, 比較了反向翻譯與Switchout (隨機從vocab選取token來替換原句中的token)

結果表明DA方法在監督式的表現與在SSL中的表現確實有正相關。

4.2 Evaluation on Text Classification Datasets

此實驗比較:

  1. UDA在NLP領域的表現, 與full supervised SOTA比較
  2. 測試UDA與unsupervised representation learning結合的效果, 考慮4種intialization schemes使用UDA與未使用UDA的表現:

a) random Transformer

b) BERT_BASE

c) BERT_LARGE

d) BERT_FINTUNE: BERT_LARGE在in-domain unlabeled data上fine-tuning

根據實驗結果有以下發現(BERT_FINTUNE沒DBpedia是因為BERT的訓練資料和DBpedia都來自Wikipedia corpus):

  1. 就算只有一點點labeled examples, 使用UDA就能與SOTA model + full supervised表現相當;尤其是2元情感分類任務, 在IMDb比SOTA model還優秀, 在Yelp-2和Amazon-2與SOTA表現相當。
  2. UDA跟transfer learning / representation learning結合性不錯, 使用UDA後皆有提升, 即使經過fine-tuning也一樣
  3. 在多類別分類的2.5k labeled examples(每個類別500 examples)設置上, full supervised跟UDA存在明顯差距, 因為多類別分類比二元分類難很多, improvment可做為future work

4.3 Ablation Studies

A. Unlabeled Data Size

可以發現在同樣數量的labeled examples下, 若減少unlabeled examples數量, 則表現明顯下降;而在同樣數量的unlabeled examples下, 減少labeled examples數量, 表現下降並不明顯 (但如果unlabeled examples跟labeled examples數量差不多就很明顯)。

由此可知unlabeled examples數量重要性>>>labeled examples數量。

B. TSA

使用Yelp-5資料集, 設置2.5k labeled examples & 6m unlabeled examples & random initialized transformer作為pre-trained representation

根據結果可知, 如果如果在最後幾個training steps才傳播supervised training signals, 則表現更好。

C. 更多不同數量的labeled examples

5. Experiment Details (Text Classification)

5.1 Datasets

UDA的unlabeled examples都源於自己剩下的資料+其他資料集, 雖然Yelp和Amazon的label distribution不太相似, 但是整個全用表現還是不錯的。

5.2 Preprocessing

input maximum sequence length = 512 subwords, 才能給BERT pretrained, 若長度超過512, 則保留後面512 subwords在IMDb表現較好。

5.3 Fine-tuning BERT on in-domain unsupervised data

使用in-domain unsupervised data fine-tune, lr設置2e-5, 5e-5, 1e-4, batch_size設置32, 64, 128, training steps設置30k, 100k, 300k, 根據BERT loss選取fine-tuned model。

5.4 Random initialized Transformer

hidden_layers設6, attention_heads設8, dropout rate (for attention and hidden states)設0.2, 並在Amazon-5, Yelp-5上用UDA個訓練500k, 1M steps

5.5 BERT hyperparameters

dropout rate設置0.1, lr設置1e-5, 2e-5, 5e-5, batch_size設置32, 128, 並在不同的data sizes上tune 30~100k steps

5.6 UDA hyperparameters

所有實驗上的loss function平衡權重λ設為1, labeled examples batch_size設為32, unlabeled examples batch_size設為224, 這樣可讓model可以train到更多unlabeled data;對於BERT_FINETUNE所有的unlabeled examples augmentation只生成一筆就好

6. Conclusion

這篇論文介紹DA其實跟SSL可以完美結合, 好的DA方法能提升SSL的表現;UDA使用SOTA DA方法來生成diverse但valid的examples來保持consistency, 讓model對於各種noise更robust;此外在NLP方面, 也顯示其與representation learning結合的有效性(e.g. BERT), 也展示能與full supervised相抗衡的表現。

最後作者也希望UDA在future works能推展到不同task當中。

結論:

  1. 比MixMatch簡單, 但UDA實驗比較嚴謹, 且表現提升非常明顯
  2. 我覺得比較特別的部分就是使用SOTA DA和TSA
  3. UDA藉由augmentation unlabeled data + consistency training的來獲取unsupervised signal
  4. 現實情況下SSL表現應該還是有限, 因此結合transfer learning很重要

--

--