본문 바로가기

Coding/Python

[Torch/Tensorflow] 활성화 함수 구현과 where() 함수

유명하거나 제시된지 오래된 활성화 함수는 컴퓨팅 환경에서 생길 수 있는 여러 문제들이 발생하지 않도록 다 손을 보고 프레임워크에 구현되지만, 따끈따끈한 논문의 함수는 직접 구현해야 한다.

 

그런 경우 지수함수 연산의 역전파 과정에서 기울기가 무한대로 튀거나,

로그에 이상치가 입력되어 정의되지 않은 값을 리턴하거나,

x 범위를 나누기 위해 텐서 단위 비교 연산을 하기 위해 몸을 비튼다거나 하는 문제가 생기기 쉽다.

 

이 때 마지막의 경우 사용할 수 있는 함수로 where() 함수가 있다.

ELU를 예시로 들어 설명하겠다.

 

사용법

- 텐서플로우의 경우 tf.where(조건, True, False)

- 토치의 경우 torch.where(조건, True, False)

 

위 ELU 함수를 구현해보면: (토치 기준)

class ELU(nn.Module):
  def __init__(self):
    super(ELU, self).__init__()
    self.__name__ = 'ELU'
  
  def forward(self, x):
    return torch.where(x > 0, x, torch.exp(x) - 1)

와 같이 사용할 수 있다. 중첩해서도 사용 가능하다.

 

다만 앞서 말한 것 처럼 지수함수를 사용할 때, x 값에 민감하게 미분 값이 변화하여 nan 손실 함수 값을 리턴할 때가 있는데, 이 경우는 where 함수로 해결할 수 없다. (구현 시 마스킹 논리의 문제인 것으로 보인다.)

 

이런 경우 mask 변수에 조건문 Boolean 값을 인가하고, 이를 사용해 where 함수 역할을 대체할 수 있다. 

코드는 아래와 같다:

class ELU(nn.Module):
  def __init__(self):
    super(ELU, self).__init__()
    self.__name__ = 'ELU'
  
  def forward(self, x):
    mask = (x > 0)
    return mask * x + (~mask * (t.exp(~mask * x) - 1))

mask의 조건문이 True면 1이 되어 해당 항만 살리고, False이면 ~mask가 1이 되어 해당 항만 살리는 논리이다.

 

exp() 인자에도 ~mask가 곱해지는 이유는 역전파 과정 때문인데,

(~mask * (t.exp(x) - 1))의 최종 결과가 ~mask로 인해  0이 된다 하더라도 중간 과정에 exp(x)가 포함되어 있기 때문에 세부 연산을 하나하나 역전파하게 되면 exp(x)의 변화율을 참조해야 하고, 그 과정에서 loss가 무한대(nan)로 튈 수 있는 경우가 생긴다.

때문에 exp(~mask * x)로 이러한 경우까지 마스킹을 해주는 것이다.

 

 

 

 

Reference

 

https://stackoverflow.com/questions/74989424/i-get-a-loss-nan-when-implementing-mish-from-the-scratch?noredirect=1#comment132337931_74989424 

 

I get a loss : nan when implementing mish from the scratch

I'm currently working on making a custom activation function using tf2 on python. model architecture: VGG 16, on CIFAR-10 epochs: 100 lr: 0.001 for initial 80 epochs, 0.0001 for 20 epochs optimizer...

stackoverflow.com

https://github.com/pytorch/pytorch/issues/68425

 

`torch.where` produces nan in backward pass for differentiable forward pass · Issue #68425 · pytorch/pytorch

To Reproduce In the example below, torch.where properly executes a differentiable forward pass but fails for calculating the correct gradient value: a = torch.tensor(100., requires_grad=True) b = t...

github.com