lossが突然跳ね上がる現象

セマンティックセグメンテーションのネットワークを訓練していると、学習がある程度進行すると減少していたlossが突然跳ね上がるように大きくなる現象が置きてしまった。

状況など

セマンティックセグメンテーションの訓練にnn.CrossEntropyLossを使う。ネットワークの最後では活性化(softmaxやsigmoid)を入れずに、logitのまま出力する。

そのときの学習のコードは

criterion = nn.CrossEntropyLoss()

for inputs, labels in enumerate(data_loader):
    outputs = model(inputs).to(device)
    loss = criterion(outputs, labels)

のようになる。このまま学習すると、softmaxに入れるモデル出力のlogitが大きくなりすぎて、softmaxに通したとき0になる要素が現れてしまい、logに通したときにInfに振れてしまう現象が起こるようである。

実はこれ、みんな大好きゼロから作るDeep Learningにも書かれている話だけど、実際に起こるとは思わなかった。ましてやPyTorchほどのライブラリがこの対策を組み込んでいないとは思っていなかったのでなかなか気づけなかった。

修正案

まずモデルの出力 nn.Softmax2d() を使うようにする。これは訓練でも推論でも両方使うので、共通化した方が良いと考えるから。

そしてロス関数を修正する。

class CrossEntropyLoss2d(nn.Module):
    def __init__(self):
        super().__init__()
        self.loss_fn = nn.NLLLoss()

    def forward(self, x, y):
        # 補足: セグメンテーション用に次元を (BATCH, C, H, W) から (C, BATCH * H * W) 変換している
        x = (x.permute(0, 2, 3, 1).contiguous().view(-1, NUM_CLASSES) + 1e-24).log()
        _, y = torch.max(y.permute(0, 2, 3, 1).contiguous().view(-1, NUM_CLASSES), -1)
        return self.loss_fn(x, y)

セグメンテーション用に次元を調整を入れ、微小地として1e-24を足すようにした。これで学習が安定するようになった。

PyTorchのデフォルトのF.log_softmax()も間にこの処理が挟まってないし、それに対応するオプションもないので、結構混乱した。