소프트맥스 함수는 출력층에서 사용되는 함수이다. 이 외에도 항등 함수, 시그모이드 함수가 존재한다. 항등 함수의 경우 입력값이 그대로 출력되는 함수로 회귀 모델을 만들 때 사용한다.
소프트맥스 함수는 다중 클래스 분류 모델을 만들 때 사용한다. 결과를 확률로 해석할 수 있게 변환해주는 함수로 높은 확률을 가지는 class로 분류한다. 이는 결과값을 정규화시키는 것으로도 생각할 수 있다.
$p_j = \frac{e^{z_j}}{\sum_{k=1}^{K} e^{z_j}}$
$j = 1,2, \dots ,K$
K는 클래스 수를 나타내며, $z_j$는 소프트맥스 함수의 입력값이다. $p_j$를 직관적으로 해석하면 $\frac{j번째 입력값}{입력값의 합}$으로 볼 수 있으며 따라서 확률 관점으로 볼 수 있다. 지수함수가 사용되는 이유는 미분이 가능하도록 하게 함이며, 입력값 중 큰 값은 더 크게 작은 값은 더 작게 만들어 입력벡터가 더 잘 구분되게 하기 위함이다.
지금부터 소프트맥스 함수가 적용되는 예시를 다룰 것이다. 이 예시는 wikidocs.net/21690을 참고 하였다.
밑의 그림에는 두가지 질문이 있다. 질문에 대한 답을 써내려가면서 소프트 맥스 함수에 대해 알아보자.
- 소프트맥스 함수의 입력으로 어떻게 바꿀까 ?
iris 데이터를 이용한 다중 클래스 분류 모델을 만드는 예시이다. 데이터는 4개의 독립변수를 가지는데 이는 모델이 4차원 벡터를 입력으로 받음을 의미한다. 그런데 소프트맥스 함수의 입력으로 사용되는 벡터의 차원은 분류하고자 하는 클래스의 개수가 되어야 하므로 어떤 가중치 연산을 통해 3차원 벡터로 변환되어야 합니다. 위의 그림에서는 소프트맥스 함수의 입력으로 사용되는 3차원 벡터를 z로 표현하였습니다.
4차원 데이터 벡터를 소프트맥스 함수의 입력 벡터로 차원을 축소하는 방법은 간단하다. 소프트맥스 함수의 입력 벡터 z의 차원수만큼 결과값이 나오도록 가중치 곱을 진행한다. 위의 그림에서 화살표는 총 (4 x 3 = 12) 12개이며 전부 다른 가중치를 가지고, 학습 과정에서 점차적으로 오차를 최소화하는 가중치로 값이 업데이트됩니다.
- 오차를 어떻게 구할까 ?
소프트맥스 함수의 출력은 분류하고자 하는 클래스의 갯수만큼 차원을 가지는 벡터로 각 원소는 0과 1사이의 값을 가지며, 이 각각은 특정 클래스가 정답일 확률을 나타낸다. 즉, 첫번째 원소인 $p_1$은 virginica가 정답일 확률, 두번째 원소인 $p_2$는 setosa가 정답일 확률, 세번째 원소인 $p_3$은 versicolor가 정답일 확률을 의미한다. 이제 이 예측값과 비교할 수 있는 실제값의 표현 방법이 있어야 한다. 소프트맥스 회귀에서는 실제을 원-핫 벡터로 표현한다.
맨 위 그림을 보면 viginica, setosa, versicolor가 1,2,3으로 인코딩된 것을 볼 수 있다. 이에 원-핫 인코딩을 수행하면 virginica는 1로 인코딩 되었기 때문에 첫번째 벡터만 1로 나타냈고, setosa는 2로 인코딩 되었기 떄문에 두번째 벡터만 1로 나타난 것을 알 수 있다.
데이터의 실제값이 setosa라면, setosa의 원-핫 벡터는 [0 1 0]이다. 이 경우, 예측값과 실제값의 오차가 0이 되는 경우는 소프트맥스 함수의 결과가 [0 1 0]이 되는 것이다. 이 두 벡터 [0.26 0.70 0.04] [0 1 0] 의 오차를 계산하기 위해서 소프트맥스 회귀는 손실함수로 cross-entropy 함수를 사용한다.
손실함수가 최소가 되는 방향으로 가중치를 업데이트 한다.
<logit . sigmoid, softmax 의 관계> - 이 글을 읽어보는 것을 추천드립니다.
- logit과 sigmoid는 서로 역함수 관계이고
- 2개 클래스 대상으로 정의하던 logit을 K개의 클래스를 대상으로 일반화하면 softmax함수가 유도된다.
- softmax함수에서 K=2로 두면 sigmoid함수로 환원이 되고, 반대로 sigmoid함수를 K개의 클래스로 일반화하면 softmax함수가 유도된다.
- 신경망에서 sigmoid는 활성화 함수로 softmax는 출력층에 사용되지만, 수학적으로는 서로 같은 함수이다.
<참고문헌>
choosunsick.github.io/post/softmax_function/
opentutorials.org/module/3653/22995
chacha95.github.io/2019-04-04-logit/
'Deep Learning > 딥러닝' 카테고리의 다른 글
활성화 함수(Activation Function) (0) | 2021.02.24 |
---|---|
퍼셉트론(Perceptron)과 오차역전파(Backpropagation) (1) | 2021.02.18 |
인공지능? 머신러닝? 딥러닝? (1) | 2021.02.16 |
댓글