[Loss, Trouble Shooting] Cross Entropy Loss vs. NLL Loss


Pytorch 공식문서에는 Cross EntropyLogSoftmaxNLLLoss 를 적용한 것과 같다고 적혀있다. 이것이 왜 같은 것인지 정리해보자.

Classification 과 Cross Entropy

  • Classification 에서는 모델의 출력을 계산하고 cross entropy loss 를 적용한다. 이는 수학적으로 매우 타당하다.
  • Classification 에서 가정하는 베르누이 분포카테고리 분포최대우도추정(MLE, Maximum Likelihood Estimation) 방식으로 추정할 때, Cross Entropy 와 유사한 형태의 loss function 이 나오게 된다.
    • 베르누이 분포의 MLE 는 Binary Cross Entropy 와 동일한 수식 형태를 가진다.
    • 카테고리 분포의 MLE 는 Categorical Cross Entropy 와 동일한 수식 형태를 가진다.

베르누이 분포의 MLE 와 Cross Entropy

  • 베르누이 분포는 두 가지 결과만 가지는 확률 분포다. 예를 들어, 성공(1)과 실패(0)로 이루어진 데이터셋을 모델링하는데 사용된다.
  • 베르누이 분포에서 데이터 포인트가 $y \in \lbrace 0, 1 \rbrace$ 이고, 성공 확률이 $p$ 라고 하면, 베르누이 분포의 확률질량함수는 아래와 같다.

    \[P(y; p) = p^y (1 - p)^{1 - y}\]
  • 이제, 데이터셋 $\lbrace y^{(1)}, y^{(2)}, \dots, y^{(n)}\rbrace$ 에 대해 MLE 를 수행하려면, 각 데이터 포인트의 log likelihood 를 구하고 이를 최대화한다. log likelihood function 은 다음과 같다.

    \[\log P(y; p) = y \log(p) + (1 - y) \log(1 - p)\]
  • 그렇다면 전체 데이터셋에 대한 log likelihood 는 아래와 같이 나타낼 수 있다.

    \[\mathcal{L}(p) = \sum_{i=1}^n \left[ y^{(i)} \log(p) + (1 - y^{(i)}) \log(1 - p) \right]\]
  • 이 식은 Binary Cross Entropy loss fucntion 의 형태와 동일하다.
  • 즉, 베르누이 분포에서 MLE 를 사용하여 likelihood 를 최대화하는 과정은 실제로 binary cross entropy 를 최소화하는 과정과 동일하다.

카테고리 분포의 MLE 와 Cross Entropy

  • 카테고리 분포는 여러 클래스 중 하나의 결과를 가지는 분포다. 예를 들어, $q$ 개의 클래스가 있는 multi classification 문제에서 사용할 수 있다.
  • 각 클래스에 속할 확률이 $\mathbf{p} = [p_1, p_2, \dots, p_q]$ 인 경우, 카테고리 분포의 확률질량함수는 다음과 같다.

    \[P(y = j; \mathbf{p}) = p_j \quad \text{where} \quad j \in \{1, 2, \dots, q\}\]
  • 이 경우에도 log likelihood 를 사용해 MLE 를 수행할 수 있다. 데이터셋 $ \lbrace y^{(1)}, y^{(2)}, \dots, y^{(n)}\rbrace$ 에 대해 log likelihood function 은 다음과 같다.

    \[\mathcal{L}(\mathbf{p}) = \sum_{i=1}^n \log P(y^{(i)}; \mathbf{p}) = \sum_{i=1}^n \log p_{y^{(i)}}\]
  • 이를 Cross Entropy loss 로 나타낼 수 있다.

    \[\mathcal{L}(\mathbf{p}) = \sum_{i=1}^n \sum_{j=1}^q y_j^{(i)} \log p_j\]
  • 여기서 $y_j^{(i)}$ 는 해당 데이터 포인트 $i$ 가 클래스 $j$ 에 속하는지 여부를 나타내는 one-hot encoding 값이다.
  • 위 식 역시 cross entropy 의 일반적인 형태다.
  • 즉, 카테고리 분포에서 MLE 를 사용하여 likelihood 를 최대화하는 과정은 multi class cross entropy 를 최소화하는 과정과 동일하다.

Softmax 의 문제점

  • 이 포스트에서 살펴봤듯, multi class cross entropy 식의 $p_j$ 는 모델의 출력인 logit 에 대해 각 클래스에 포함될 확률을 softmax 로 구한 것이다.
  • 이 때 softmax 의 지수함수는 numerical under-flow 와 over-flow 를 유발하여 계산적으로 위험할 수 있다.
  • softmax 함수가 확률을 계산하는 방법은 아래와 같다.

    \[\hat y_j = \frac{\exp(o_j)}{\sum_k \exp(o_k)}\]
  • 만약 일부 $o_k$ 값이 매우 큰 양수라면, $\exp(o_k)$ 는 컴퓨터가 특정 데이터 타입에서 가질 수 있는 가장 큰 숫자보다 커질 수 있다. 이를 over-flow 라고 한다.
  • 반대로, 모든 매우 큰 음수라면, under-flow 가 발생할 수 있다.
  • 예를 들어, single precision floating point(부동소수점) 숫자는 대략 $10^{-38}$ 에서 $10^{38}$ 사이의 범위를 다룬다.
  • 따라서 $\mathbf{o}$ 에서 가장 큰 항이 $[-90, 90]$ 범위를 벗어나면 결과가 안정적이지 않게 된다.

LogSumExp trick

  • 위 문제를 해결하는 방법은 $\bar{o} \stackrel{\textrm{def}}{=} \max_k o_k$ 를 모든 항에서 빼는 것이다. 이른바 LogSumExp trick 이다.

    \[\hat y_j = \frac{\exp o_j}{\sum_k \exp o_k} = \frac{\exp(o_j - \bar{o}) \exp \bar{o}}{\sum_k \exp (o_k - \bar{o}) \exp \bar{o}} = \frac{\exp(o_j - \bar{o})}{\sum_k \exp (o_k - \bar{o})}\]
  • 최대값을 빼주기 때문에 모든 $j$ 에 대해 $o_j - \bar{o} \leq 0$ 임을 알 수 있다.
  • 그러면 $q$ 개의 클래스를 가진 classification 문제에서 분모는 $[1, q]$ 범위 내에 있다.
  • 또한, 분자는 절대로 1 을 초과하지 않으므로 over-flow 를 방지할 수 있다. 그리고 under-flow 는 $\exp(o_j - \bar{o})$ 가 수치적으로 0 으로 계산될 때만 발생한다.
  • 그러나 우리가 $\log \hat{y}_j$ 를 계산하려고 할 때, $\log 0$ 이 문제가 되어 어려움을 겪을 수 있다. 이 때문에 역전파 과정에서 NaN(Not a Number) 결과가 화면 가득히 나올 수 있다.
  • 다행인 것은 우리가 지수함수들을 계산하고 있지만, cross entropy 를 계산할 때 궁극적으로는 그 log 를 취하려고 한다는 점이다.
  • softmax 와 cross entropy 를 결합함으로써 우리는 이러한 수치적 안정성 문제를 완전히 회피할 수 있다.

    \[\log \hat{y}_j = \log \frac{\exp(o_j - \bar{o})}{\sum_k \exp (o_k - \bar{o})} = o_j - \bar{o} - \log \sum_k \exp (o_k - \bar{o})\]
  • 이렇게 하면 over-flow 와 under-flow 를 모두 피할 수 있다.
  • Pytorch 의 cross entropy loss 를 쓸 때는 softmax 확률을 loss 에 전달하는 대신, softmax 취하기 전 모델의 출력값인 logit 을 전달하고 softmax 와 그 log 를 cross entropy loss 함수 내에서 한 번에 계산한다.
  • Pytorch 의 cross entropy 는 위와 같은 LogSumExp trick 을 써서 over-flow 와 under-flow 를 피한다.

log_softmax

  • Pytorch 의 log_softmax 는 softmax 함수의 출력을 직접 계산한 다음에 log 를 취하는 것이 아니라, 이를 수치적으로 더 안정적으로 계산하기 위해 LogSumExp trick 을 사용한다.
  • 이를 통해 under-flow 나 over-flow 와 같은 문제를 방지할 수 있다.
  • log_softmax 는 softmax 에 log 를 취한 것이다.

    \[\log(\text{softmax}(o_j)) = \log \frac{\exp o_j}{\sum_k \exp o_k} = o_j - \log \left( \sum_k \exp o_k \right)\]
  • 여기서 두번째 항은 직접 계산하면 매우 큰 값이 될 수 있는데, 이는 수치적으로 불안정할 수 있다.
  • 따라서 이 계산을 안정적으로 하기 위해서 가장 큰 값을 빼준 후 계산하는 LogSumEXP trick 이 사용된다.

    \[\log \left( \sum_k \exp o_k \right) = \bar{o} + \log \left( \sum_k \exp (o_k - \bar{o}) \right)\]
  • 여기서 $\bar{o}$ 는 $\max(o_k)$ 다.
  • 이렇게 수치적 안정성을 위한 trick 이 적용되어 Pytorch 에서는 softmax 에 log 를 취할 때 아래와 같은 식이 기본적으로 수행되는 연산이라는 것이다.

    \[\log \hat{y}_j = \log \frac{\exp(o_j - \bar{o})}{\sum_k \exp (o_k - \bar{o})} = o_j - \bar{o} - \log \sum_k \exp (o_k - \bar{o})\]
  • softmax 연산을 수행하면 elements 들이 0 에서 1 사이의 확률을 갖게 되고, 총합이 1 이 된다.
  • Pytorch 에서는 이 softmax 의 결과를 직접적으로 loss 에 사용하면 수치적으로 불한정하여 NaN 값을 얻을 수 있으므로 log_softmaxNLL Loss 를 결합하는 것을 추천한다.
  • 또한 log 와 softmax 두 연산을 따로 수행하는 것보다 더 빠르고 수치적으로 안정적인 log_softmax 함수를 제공한다.

Cross Entropy loss

  • Pytorch 공식 문서의 Cross Entropy loss 식을 다시 보면 아래와 같다.

    \[\ell(x, y) = L = \lbrace l_1, \ldots, l_N \rbrace^\top, \quad l_n = -\sum^C_{c=1} w_c \log \left(\frac{\exp(x_{n, c})}{\sum^C_{i=1}\exp(x_{n, i})}\right)y_{n,c}\]
  • 이는 카테고리 분포에 MLE 를 취했을 때 나오는 log likelihood function 의 부호를 바꿔서 minimize 하는 문제로 치환한 것임을 알 수 있다.
  • negative log likelihood 를 이용해서 최적화하는 것이다.
  • 또한 위 식을 자세히 보면 log 안의 식은 softmax 이기 때문에 log_softmax 가 쓰인 것을 알 수 있다.
  • 따라서 Cross Entropy Loss 는 log_softmax 와 NLL loss 가 함께 사용된 것이다.
  • 이러한 이유 때문에 모델의 마지막 output 에 log_softmax 를 사용하면 안된다고 하지만, 엄밀히 말하면 사용하지 말아야되는 것은 아니다.
  • 왜냐하면 $\text{log_softmax}(\text{log_softmax}(x)) = \text{log_softmax(x)}$ 이기 때문이다.
  • 즉 output 에 log_softmax 를 취하지 않은 logit 을 Cross Entropy 에 넣으나, log_softmax 를 취하여 확률값으로 Cross Entropy 에 넣으나 연산 결과는 동일하다는 것이다.
  • 따라서 cross entropy loss 를 사용할 때 해당 함수의 input 으로 model 의 raw output 인 logit 을 꼭 사용할 필요는 없다.

NLL loss

  • Pytorch 공식 문서의 NLL loss 식을 보면 아래와 같다.

    \[\ell(x, y) = L = \lbrace l_1, \ldots, l_N \rbrace^\top, \quad l_n = -w_{y_n}x_{n, y_n}\]
  • 즉 NLL loss 안에서는 softmax 나 log 연산이 이뤄지지 않는다.
  • 따라서 모델의 logit 을 loss 의 input 그대로 사용하는 것이 아니라 log_softmax 함수를 적용한 후 input 으로 사용해야 한다.
  • 이러한 이유로 NLL loss 를 사용하면 모델의 마지막 layer 에 log_softmax 가 있어야 한다.

코드 및 정리

  • cross entropy loss 를 사용할 때의 모델 형태를 코드로 보자.

    class CNN(nn.module):
        def __init__(self):
            super(CNN, self).__init__()
            self.layer1 = nn.Sequential(
              nn.Conv2d(1, 32, (3, 3), padding=1, stride=(1,1)),
              nn.ReLU(),
              nn.MaxPool2d((2,2), padding=0, stride=2)
            )
            self.layer2 = nn.Sequential(
              nn.Conv2d(32, 64, (3, 3), padding=1, stride=(1,1)),
              nn.ReLU(),
              nn.MaxPool2d((2,2), padding=0, stride=(2,2))
            )
            self.fc = nn.Linear(7*7*64, 10, bias=True)
    
            torch.nn.init.xavier_uniform_(self.fc.weight)
    
        def forward(self, x):
            x = self.layer1(x)
            x = self.layer2(x)
            x = x.view(x.shape[0], -1)
            x = self.fc(x)
              
            return x
    
    • forward 함수에 보면 fc layer 를 거친 logit 값을 출력하는 것을 볼 수 있다. 이를 Cross Entropy loss 에 넣어주는 것이다.
  • 이제 NLL loss 를 사용할 때의 모델 형태를 코드로 보자.

    class CNN(nn.module):
        def __init__(self):
            super(CNN, self).__init__()
            self.layer1 = nn.Sequential(
              nn.Conv2d(1, 32, (3, 3), padding=1, stride=(1,1)),
              nn.ReLU(),
              nn.MaxPool2d((2,2), padding=0, stride=2)
            )
            self.layer2 = nn.Sequential(
              nn.Conv2d(32, 64, (3, 3), padding=1, stride=(1,1)),
              nn.ReLU(),
              nn.MaxPool2d((2,2), padding=0, stride=(2,2))
            )
            self.fc = nn.Linear(7*7*64, 10, bias=True)
    
            torch.nn.init.xavier_uniform_(self.fc.weight)
    
        def forward(self, x):
            x = self.layer1(x)
            x = self.layer2(x)
            x = x.view(x.shape[0], -1)
            x = self.fc(x)
              
            return F.log_softmax(x)
    
    • fc layer 를 지나 log_softmax 함수가 적용되는 것을 볼 수 있다.
맨 위로 이동 ↑

댓글 남기기