[Loss, Trouble Shooting] Cross Entropy Loss vs. NLL Loss
Pytorch 공식문서에는 Cross Entropy 가 LogSoftmax
와 NLLLoss
를 적용한 것과 같다고 적혀있다. 이것이 왜 같은 것인지 정리해보자.
Classification 과 Cross Entropy
- Classification 에서는 모델의 출력을 계산하고 cross entropy loss 를 적용한다. 이는 수학적으로 매우 타당하다.
- Classification 에서 가정하는 베르누이 분포나 카테고리 분포를 최대우도추정(MLE, Maximum Likelihood Estimation) 방식으로 추정할 때, Cross Entropy 와 유사한 형태의 loss function 이 나오게 된다.
- 베르누이 분포의 MLE 는 Binary Cross Entropy 와 동일한 수식 형태를 가진다.
- 카테고리 분포의 MLE 는 Categorical Cross Entropy 와 동일한 수식 형태를 가진다.
베르누이 분포의 MLE 와 Cross Entropy
- 베르누이 분포는 두 가지 결과만 가지는 확률 분포다. 예를 들어, 성공(1)과 실패(0)로 이루어진 데이터셋을 모델링하는데 사용된다.
-
베르누이 분포에서 데이터 포인트가 $y \in \lbrace 0, 1 \rbrace$ 이고, 성공 확률이 $p$ 라고 하면, 베르누이 분포의 확률질량함수는 아래와 같다.
\[P(y; p) = p^y (1 - p)^{1 - y}\] -
이제, 데이터셋 $\lbrace y^{(1)}, y^{(2)}, \dots, y^{(n)}\rbrace$ 에 대해 MLE 를 수행하려면, 각 데이터 포인트의 log likelihood 를 구하고 이를 최대화한다. log likelihood function 은 다음과 같다.
\[\log P(y; p) = y \log(p) + (1 - y) \log(1 - p)\] -
그렇다면 전체 데이터셋에 대한 log likelihood 는 아래와 같이 나타낼 수 있다.
\[\mathcal{L}(p) = \sum_{i=1}^n \left[ y^{(i)} \log(p) + (1 - y^{(i)}) \log(1 - p) \right]\] - 이 식은 Binary Cross Entropy loss fucntion 의 형태와 동일하다.
- 즉, 베르누이 분포에서 MLE 를 사용하여 likelihood 를 최대화하는 과정은 실제로 binary cross entropy 를 최소화하는 과정과 동일하다.
카테고리 분포의 MLE 와 Cross Entropy
- 카테고리 분포는 여러 클래스 중 하나의 결과를 가지는 분포다. 예를 들어, $q$ 개의 클래스가 있는 multi classification 문제에서 사용할 수 있다.
-
각 클래스에 속할 확률이 $\mathbf{p} = [p_1, p_2, \dots, p_q]$ 인 경우, 카테고리 분포의 확률질량함수는 다음과 같다.
\[P(y = j; \mathbf{p}) = p_j \quad \text{where} \quad j \in \{1, 2, \dots, q\}\] -
이 경우에도 log likelihood 를 사용해 MLE 를 수행할 수 있다. 데이터셋 $ \lbrace y^{(1)}, y^{(2)}, \dots, y^{(n)}\rbrace$ 에 대해 log likelihood function 은 다음과 같다.
\[\mathcal{L}(\mathbf{p}) = \sum_{i=1}^n \log P(y^{(i)}; \mathbf{p}) = \sum_{i=1}^n \log p_{y^{(i)}}\] -
이를 Cross Entropy loss 로 나타낼 수 있다.
\[\mathcal{L}(\mathbf{p}) = \sum_{i=1}^n \sum_{j=1}^q y_j^{(i)} \log p_j\] - 여기서 $y_j^{(i)}$ 는 해당 데이터 포인트 $i$ 가 클래스 $j$ 에 속하는지 여부를 나타내는 one-hot encoding 값이다.
- 위 식 역시 cross entropy 의 일반적인 형태다.
- 즉, 카테고리 분포에서 MLE 를 사용하여 likelihood 를 최대화하는 과정은 multi class cross entropy 를 최소화하는 과정과 동일하다.
Softmax 의 문제점
- 이 포스트에서 살펴봤듯, multi class cross entropy 식의 $p_j$ 는 모델의 출력인 logit 에 대해 각 클래스에 포함될 확률을 softmax 로 구한 것이다.
- 이 때 softmax 의 지수함수는 numerical under-flow 와 over-flow 를 유발하여 계산적으로 위험할 수 있다.
-
softmax 함수가 확률을 계산하는 방법은 아래와 같다.
\[\hat y_j = \frac{\exp(o_j)}{\sum_k \exp(o_k)}\] - 만약 일부 $o_k$ 값이 매우 큰 양수라면, $\exp(o_k)$ 는 컴퓨터가 특정 데이터 타입에서 가질 수 있는 가장 큰 숫자보다 커질 수 있다. 이를 over-flow 라고 한다.
- 반대로, 모든 매우 큰 음수라면, under-flow 가 발생할 수 있다.
- 예를 들어, single precision floating point(부동소수점) 숫자는 대략 $10^{-38}$ 에서 $10^{38}$ 사이의 범위를 다룬다.
- 따라서 $\mathbf{o}$ 에서 가장 큰 항이 $[-90, 90]$ 범위를 벗어나면 결과가 안정적이지 않게 된다.
LogSumExp trick
-
위 문제를 해결하는 방법은 $\bar{o} \stackrel{\textrm{def}}{=} \max_k o_k$ 를 모든 항에서 빼는 것이다. 이른바 LogSumExp trick 이다.
\[\hat y_j = \frac{\exp o_j}{\sum_k \exp o_k} = \frac{\exp(o_j - \bar{o}) \exp \bar{o}}{\sum_k \exp (o_k - \bar{o}) \exp \bar{o}} = \frac{\exp(o_j - \bar{o})}{\sum_k \exp (o_k - \bar{o})}\] - 최대값을 빼주기 때문에 모든 $j$ 에 대해 $o_j - \bar{o} \leq 0$ 임을 알 수 있다.
- 그러면 $q$ 개의 클래스를 가진 classification 문제에서 분모는 $[1, q]$ 범위 내에 있다.
- 또한, 분자는 절대로 1 을 초과하지 않으므로 over-flow 를 방지할 수 있다. 그리고 under-flow 는 $\exp(o_j - \bar{o})$ 가 수치적으로 0 으로 계산될 때만 발생한다.
- 그러나 우리가 $\log \hat{y}_j$ 를 계산하려고 할 때, $\log 0$ 이 문제가 되어 어려움을 겪을 수 있다. 이 때문에 역전파 과정에서
NaN
(Not a Number) 결과가 화면 가득히 나올 수 있다. - 다행인 것은 우리가 지수함수들을 계산하고 있지만, cross entropy 를 계산할 때 궁극적으로는 그 log 를 취하려고 한다는 점이다.
-
softmax 와 cross entropy 를 결합함으로써 우리는 이러한 수치적 안정성 문제를 완전히 회피할 수 있다.
\[\log \hat{y}_j = \log \frac{\exp(o_j - \bar{o})}{\sum_k \exp (o_k - \bar{o})} = o_j - \bar{o} - \log \sum_k \exp (o_k - \bar{o})\] - 이렇게 하면 over-flow 와 under-flow 를 모두 피할 수 있다.
- Pytorch 의 cross entropy loss 를 쓸 때는 softmax 확률을 loss 에 전달하는 대신, softmax 취하기 전 모델의 출력값인 logit 을 전달하고 softmax 와 그 log 를 cross entropy loss 함수 내에서 한 번에 계산한다.
- Pytorch 의 cross entropy 는 위와 같은 LogSumExp trick 을 써서 over-flow 와 under-flow 를 피한다.
log_softmax
- Pytorch 의 log_softmax 는 softmax 함수의 출력을 직접 계산한 다음에 log 를 취하는 것이 아니라, 이를 수치적으로 더 안정적으로 계산하기 위해 LogSumExp trick 을 사용한다.
- 이를 통해 under-flow 나 over-flow 와 같은 문제를 방지할 수 있다.
-
\[\log(\text{softmax}(o_j)) = \log \frac{\exp o_j}{\sum_k \exp o_k} = o_j - \log \left( \sum_k \exp o_k \right)\]log_softmax
는 softmax 에 log 를 취한 것이다. - 여기서 두번째 항은 직접 계산하면 매우 큰 값이 될 수 있는데, 이는 수치적으로 불안정할 수 있다.
-
따라서 이 계산을 안정적으로 하기 위해서 가장 큰 값을 빼준 후 계산하는 LogSumEXP trick 이 사용된다.
\[\log \left( \sum_k \exp o_k \right) = \bar{o} + \log \left( \sum_k \exp (o_k - \bar{o}) \right)\] - 여기서 $\bar{o}$ 는 $\max(o_k)$ 다.
-
이렇게 수치적 안정성을 위한 trick 이 적용되어 Pytorch 에서는 softmax 에 log 를 취할 때 아래와 같은 식이 기본적으로 수행되는 연산이라는 것이다.
\[\log \hat{y}_j = \log \frac{\exp(o_j - \bar{o})}{\sum_k \exp (o_k - \bar{o})} = o_j - \bar{o} - \log \sum_k \exp (o_k - \bar{o})\] - softmax 연산을 수행하면 elements 들이 0 에서 1 사이의 확률을 갖게 되고, 총합이 1 이 된다.
- Pytorch 에서는 이 softmax 의 결과를 직접적으로 loss 에 사용하면 수치적으로 불한정하여
NaN
값을 얻을 수 있으므로log_softmax
와NLL Loss
를 결합하는 것을 추천한다. - 또한 log 와 softmax 두 연산을 따로 수행하는 것보다 더 빠르고 수치적으로 안정적인
log_softmax
함수를 제공한다.
Cross Entropy loss
-
Pytorch 공식 문서의 Cross Entropy loss 식을 다시 보면 아래와 같다.
\[\ell(x, y) = L = \lbrace l_1, \ldots, l_N \rbrace^\top, \quad l_n = -\sum^C_{c=1} w_c \log \left(\frac{\exp(x_{n, c})}{\sum^C_{i=1}\exp(x_{n, i})}\right)y_{n,c}\] - 이는 카테고리 분포에 MLE 를 취했을 때 나오는 log likelihood function 의 부호를 바꿔서 minimize 하는 문제로 치환한 것임을 알 수 있다.
- 즉 negative log likelihood 를 이용해서 최적화하는 것이다.
- 또한 위 식을 자세히 보면 log 안의 식은 softmax 이기 때문에 log_softmax 가 쓰인 것을 알 수 있다.
- 따라서 Cross Entropy Loss 는 log_softmax 와 NLL loss 가 함께 사용된 것이다.
- 이러한 이유 때문에 모델의 마지막 output 에
log_softmax
를 사용하면 안된다고 하지만, 엄밀히 말하면 사용하지 말아야되는 것은 아니다. - 왜냐하면 $\text{log_softmax}(\text{log_softmax}(x)) = \text{log_softmax(x)}$ 이기 때문이다.
- 즉 output 에 log_softmax 를 취하지 않은 logit 을 Cross Entropy 에 넣으나, log_softmax 를 취하여 확률값으로 Cross Entropy 에 넣으나 연산 결과는 동일하다는 것이다.
- 따라서 cross entropy loss 를 사용할 때 해당 함수의 input 으로 model 의 raw output 인 logit 을 꼭 사용할 필요는 없다.
NLL loss
-
Pytorch 공식 문서의 NLL loss 식을 보면 아래와 같다.
\[\ell(x, y) = L = \lbrace l_1, \ldots, l_N \rbrace^\top, \quad l_n = -w_{y_n}x_{n, y_n}\] - 즉 NLL loss 안에서는 softmax 나 log 연산이 이뤄지지 않는다.
- 따라서 모델의 logit 을 loss 의 input 그대로 사용하는 것이 아니라
log_softmax
함수를 적용한 후 input 으로 사용해야 한다. - 이러한 이유로 NLL loss 를 사용하면 모델의 마지막 layer 에
log_softmax
가 있어야 한다.
코드 및 정리
-
cross entropy loss 를 사용할 때의 모델 형태를 코드로 보자.
class CNN(nn.module): def __init__(self): super(CNN, self).__init__() self.layer1 = nn.Sequential( nn.Conv2d(1, 32, (3, 3), padding=1, stride=(1,1)), nn.ReLU(), nn.MaxPool2d((2,2), padding=0, stride=2) ) self.layer2 = nn.Sequential( nn.Conv2d(32, 64, (3, 3), padding=1, stride=(1,1)), nn.ReLU(), nn.MaxPool2d((2,2), padding=0, stride=(2,2)) ) self.fc = nn.Linear(7*7*64, 10, bias=True) torch.nn.init.xavier_uniform_(self.fc.weight) def forward(self, x): x = self.layer1(x) x = self.layer2(x) x = x.view(x.shape[0], -1) x = self.fc(x) return x
- forward 함수에 보면 fc layer 를 거친 logit 값을 출력하는 것을 볼 수 있다. 이를 Cross Entropy loss 에 넣어주는 것이다.
-
이제 NLL loss 를 사용할 때의 모델 형태를 코드로 보자.
class CNN(nn.module): def __init__(self): super(CNN, self).__init__() self.layer1 = nn.Sequential( nn.Conv2d(1, 32, (3, 3), padding=1, stride=(1,1)), nn.ReLU(), nn.MaxPool2d((2,2), padding=0, stride=2) ) self.layer2 = nn.Sequential( nn.Conv2d(32, 64, (3, 3), padding=1, stride=(1,1)), nn.ReLU(), nn.MaxPool2d((2,2), padding=0, stride=(2,2)) ) self.fc = nn.Linear(7*7*64, 10, bias=True) torch.nn.init.xavier_uniform_(self.fc.weight) def forward(self, x): x = self.layer1(x) x = self.layer2(x) x = x.view(x.shape[0], -1) x = self.fc(x) return F.log_softmax(x)
- fc layer 를 지나
log_softmax
함수가 적용되는 것을 볼 수 있다.
- fc layer 를 지나
댓글 남기기