lossが突然跳ね上がる現象

セマンティックセグメンテーションのネットワークを訓練していると、学習がある程度進行すると減少していたlossが突然跳ね上がるように大きくなる現象が置きてしまった。

状況など

基本的なセマンティックセグメンテーションの訓練にnn.CrossEntropyLossを使う。これを使うときはモデルの出力には活性関数を通さずに(softmaxやsigmoidなど)せずに、logitのまま出力する。

そのときの学習のコードは

criterion = nn.CrossEntropyLoss()

for inputs, labels in enumerate(data_loader):
    outputs = model(inputs).to(device)
    loss = criterion(outputs, labels)

のようになる。このまま学習すると、softmaxに入れるモデル出力のlogitが大きくなりすぎて、softmaxに通したとき0になる要素が現れてしまい、logに通したときにInfに振れてしまう現象が起こるようである。

修正案

nn.CrossEntropyLoss()softmax → cross-entrooy (NLL → log) を一挙に行う関数として提供されているので、これをsoftmax → NLL → 微小値付与 → log と行うようにする。

実装としては、

  1. softmax はモデル出力の最後にくっつける
  2. NLL → 微小値付与 → logを行うloss関数のモジュールを定義する

という感じにする。前者は、モデルの最後に活性化関数としてnn.Softmax2d()を通すようにする(そもそも、この方が評価時にも使いやすい)。後者は、以下のように実装した。

class CrossEntropyLoss2d(nn.Module):
    def __init__(self):
        super().__init__()
        self.loss_fn = nn.NLLLoss()

    def forward(self, x, y):
        # (BATCH, C, H, W) から (C, BATCH * H * W) に転置
        x = x.permute(0, 2, 3, 1).contiguous().view(-1, NUM_CLASSES)
        # 微小値を足す
        x = x + 1e-24
        x = x.log()

        y = y.permute(0, 2, 3, 1).contiguous().view(-1, NUM_CLASSES)
        _, y = torch.max(y, -1)
        return self.loss_fn(x, y)

xyはどちらも (Batch, Channel, Height, Width) で入ってくると想定している(各値はlogitではなく、p(確率))。NLLLossはsoftmaxのlogをしないバージョンで、 (Batch, Channel) 形式の入力を期待するので、transpose→flattenして変形する。

そしてlogに入れる前に、微小値として1e-24を足すようにした。厳密にはmax(p, 1e-24)の方が適切だが、加算でも学習できるし何より遅くなるのでシンプルに一律加算にしている。これで学習が安定するようになった。

ひとこと

みんな大好きゼロから作るDeep Learningにも書かれている話だけど、実際に起こるとは思わなかった。PyTorchのデフォルトのF.log_softmax()も間にこの処理が挟まってないし、それに対応するオプションもないので混乱した。