유명하거나 제시된지 오래된 활성화 함수는 컴퓨팅 환경에서 생길 수 있는 여러 문제들이 발생하지 않도록 다 손을 보고 프레임워크에 구현되지만, 따끈따끈한 논문의 함수는 직접 구현해야 한다.
그런 경우 지수함수 연산의 역전파 과정에서 기울기가 무한대로 튀거나,
로그에 이상치가 입력되어 정의되지 않은 값을 리턴하거나,
x 범위를 나누기 위해 텐서 단위 비교 연산을 하기 위해 몸을 비튼다거나 하는 문제가 생기기 쉽다.
이 때 마지막의 경우 사용할 수 있는 함수로 where() 함수가 있다.
ELU를 예시로 들어 설명하겠다.
사용법
- 텐서플로우의 경우 tf.where(조건, True, False)
- 토치의 경우 torch.where(조건, True, False)
위 ELU 함수를 구현해보면: (토치 기준)
class ELU(nn.Module):
def __init__(self):
super(ELU, self).__init__()
self.__name__ = 'ELU'
def forward(self, x):
return torch.where(x > 0, x, torch.exp(x) - 1)
와 같이 사용할 수 있다. 중첩해서도 사용 가능하다.
다만 앞서 말한 것 처럼 지수함수를 사용할 때, x 값에 민감하게 미분 값이 변화하여 nan 손실 함수 값을 리턴할 때가 있는데, 이 경우는 where 함수로 해결할 수 없다. (구현 시 마스킹 논리의 문제인 것으로 보인다.)
이런 경우 mask 변수에 조건문 Boolean 값을 인가하고, 이를 사용해 where 함수 역할을 대체할 수 있다.
코드는 아래와 같다:
class ELU(nn.Module):
def __init__(self):
super(ELU, self).__init__()
self.__name__ = 'ELU'
def forward(self, x):
mask = (x > 0)
return mask * x + (~mask * (t.exp(~mask * x) - 1))
mask의 조건문이 True면 1이 되어 해당 항만 살리고, False이면 ~mask가 1이 되어 해당 항만 살리는 논리이다.
exp() 인자에도 ~mask가 곱해지는 이유는 역전파 과정 때문인데,
(~mask * (t.exp(x) - 1))의 최종 결과가 ~mask로 인해 0이 된다 하더라도 중간 과정에 exp(x)가 포함되어 있기 때문에 세부 연산을 하나하나 역전파하게 되면 exp(x)의 변화율을 참조해야 하고, 그 과정에서 loss가 무한대(nan)로 튈 수 있는 경우가 생긴다.
때문에 exp(~mask * x)로 이러한 경우까지 마스킹을 해주는 것이다.
Reference
https://github.com/pytorch/pytorch/issues/68425
'Coding > Python' 카테고리의 다른 글
[Python]JSON, CSV (0) | 2023.03.02 |
---|---|
[Torch] 사전 훈련된 모델 불러오기, 학습 여부 조정 (0) | 2023.01.17 |
[Tensorflow/Python] LSTM 레이어 사용 시 cuDNN 커널 사용 불가 오류 (0) | 2023.01.02 |