ディープラーニングのモデルの軽量化(コンパクト化)について。
目次
蒸留
- 「蒸留」→層が深く高精度なネットワークが学習した知識を、層が浅く軽量なニューラルネットワークへと伝える方法
- 精度を損なわずにネットワークの圧縮をすることができる
- 計算コストの減少・正則化効果・クラス数や学習データが多い場合の学習の効率化が期待できる
- 「教師ネットワーク」→層が深く複雑な学習済みのネットワーク
- 「生徒ネットワーク」→教師が学んだ知識を伝える未学習のネットワーク
- 教師ネットワークの事後確率を正解ラベルとし、生徒ネットワークの学習を行い、知識を伝達する
- 「hard target」→通常学習時に用いるone-hotな正解ラベル
- 「soft target」→教師ネットワークの事後確率
- 正解クラス以外の値は、”正解クラスとの類似度”と捉えることができ、このような形での知識の獲得が、汎化能力の向上へとつながっている
枝刈り
- 「枝刈り」→ニューラルネットワークの接続の一部を切断する=重みの値を0にする処理
- 「マグニチュードベース」→重みの値が一定より小さければ0にする方法
(絶対値の小さいものから順番に削除) - 「勾配ベース」→勾配情報を利用する手法
(モデルにデータを入力→各クラスの確信度を出力→正解クラスの確信度から誤差を逆伝播→各重みの評価値を算出→評価値(感度)が小さいものから順に削除) - 学習と枝刈りを繰り返し、少しずつモデルサイズを圧縮していく
(いきなり重みを多く削除すると、精度に悪影響を及ぼす) - 「宝くじ仮説」→より良い初期値(宝くじ)を持つサブネットワークがモデル全体に含まれているという考え方。
→枝刈り後のモデルの初期値には、元のモデルの値を利用して学習する - 枝刈りを行って学習する際に、重みを初期値に戻す(重みを巻き戻す)→高い精度になる
量子化
- 「量子化」→重みなどのパラメータを、少ないビット数で表現し、モデルを圧縮する
- 使用するビット数を制限→ネットワーク構造を変えずにメモリ使用量を削減できる
- 「32ビット浮動小数点数」→ディープラーニングの学習時
- 「8ビット整数」→推論時には小さな値は不要→メモリ使用量の削減
- ディープラーニングのフレームワークで採用されている