什么是GELU激活?


Answers:


19

GELU功能

我们可以展开的累积分布ñ01个,即ΦX,如下:

格鲁X:=XPXX=XΦX=0.5X1个+埃尔夫X2

请注意,这是一个定义,而不是方程式(或关系)。作者为此提议提供了一些理由,例如随机类比,但是从数学上讲,这只是一个定义。

这是GELU的图:

tanh近似

对于这些类型的数值逼近,关键思想是找到一个相似的函数(主要基于经验),对其进行参数化,然后将其拟合至原始函数中的一组点。

知道埃尔夫X非常接近X

erf的一阶导数x埃尔夫X2tanh相符2πXX=0,这是2π,我们继续拟合

2πX+一种X2+bX3+CX4+dX5
(或带有更多项)到一组点X一世埃尔夫X一世2

我已经安装此功能之间的20个样品-1.51.5使用本网站),这里是系数:

通过设置一种=C=d=0b估计为0.04495641。如果有更多样本处于更宽的范围内(该位置仅允许20个样本),系数b将更接近纸张的0.044715。最后我们得到

GELU(x)=xΦ(x)=0.5x(1+erf(x2))0.5x(1+tanh(2π(x+0.044715x3)))

与均方误差108x[10,10]

请注意,如果我们未利用一阶导数之间的关系,则术语2π将被包含在参数中,如下所示:

0.5X1个+0.797885X+0.035677X3
,它不那么漂亮(分析性更强,数值更大)!

利用平价

正如@BookYourLuck所建议的,我们可以利用函数的奇偶性来限制搜索多项式的空间。也就是说,由于erf是奇函数,即f(x)=f(x),和tanh也是奇函数,多项式函数pol(x)tanh也应该是奇数(应该仅具有奇次幂x)有

erf(x)tanh(pol(x))=tanh(pol(x))=tanh(pol(x))erf(x)

以前,我们很幸运最后得到偶数幂x2x4(几乎)零系数,但是通常,这可能导致低质量近似,例如,像0.23x2这样的项被取消了。不用额外选择(偶数或奇数),而不是简单地选择0x2

乙状结肠逼近

erf(x2σX-1个210-4X[-1010]

这是用于生成数据点,拟合函数并计算均方误差的Python代码:

import math
import numpy as np
import scipy.optimize as optimize


def tahn(xs, a):
    return [math.tanh(math.sqrt(2 / math.pi) * (x + a * x**3)) for x in xs]


def sigmoid(xs, a):
    return [2 * (1 / (1 + math.exp(-a * x)) - 0.5) for x in xs]


print_points = 0
np.random.seed(123)
# xs = [-2, -1, -.9, -.7, 0.6, -.5, -.4, -.3, -0.2, -.1, 0,
#       .1, 0.2, .3, .4, .5, 0.6, .7, .9, 2]
# xs = np.concatenate((np.arange(-1, 1, 0.2), np.arange(-4, 4, 0.8)))
# xs = np.concatenate((np.arange(-2, 2, 0.5), np.arange(-8, 8, 1.6)))
xs = np.arange(-10, 10, 0.001)
erfs = np.array([math.erf(x/math.sqrt(2)) for x in xs])
ys = np.array([0.5 * x * (1 + math.erf(x/math.sqrt(2))) for x in xs])

# Fit tanh and sigmoid curves to erf points
tanh_popt, _ = optimize.curve_fit(tahn, xs, erfs)
print('Tanh fit: a=%5.5f' % tuple(tanh_popt))

sig_popt, _ = optimize.curve_fit(sigmoid, xs, erfs)
print('Sigmoid fit: a=%5.5f' % tuple(sig_popt))

# curves used in https://mycurvefit.com:
# 1. sinh(sqrt(2/3.141593)*(x+a*x^2+b*x^3+c*x^4+d*x^5))/cosh(sqrt(2/3.141593)*(x+a*x^2+b*x^3+c*x^4+d*x^5))
# 2. sinh(sqrt(2/3.141593)*(x+b*x^3))/cosh(sqrt(2/3.141593)*(x+b*x^3))
y_paper_tanh = np.array([0.5 * x * (1 + math.tanh(math.sqrt(2/math.pi)*(x + 0.044715 * x**3))) for x in xs])
tanh_error_paper = (np.square(ys - y_paper_tanh)).mean()
y_alt_tanh = np.array([0.5 * x * (1 + math.tanh(math.sqrt(2/math.pi)*(x + tanh_popt[0] * x**3))) for x in xs])
tanh_error_alt = (np.square(ys - y_alt_tanh)).mean()

# curve used in https://mycurvefit.com:
# 1. 2*(1/(1+2.718281828459^(-(a*x))) - 0.5)
y_paper_sigmoid = np.array([x * (1 / (1 + math.exp(-1.702 * x))) for x in xs])
sigmoid_error_paper = (np.square(ys - y_paper_sigmoid)).mean()
y_alt_sigmoid = np.array([x * (1 / (1 + math.exp(-sig_popt[0] * x))) for x in xs])
sigmoid_error_alt = (np.square(ys - y_alt_sigmoid)).mean()

print('Paper tanh error:', tanh_error_paper)
print('Alternative tanh error:', tanh_error_alt)
print('Paper sigmoid error:', sigmoid_error_paper)
print('Alternative sigmoid error:', sigmoid_error_alt)

if print_points == 1:
    print(len(xs))
    for x, erf in zip(xs, erfs):
        print(x, erf)

输出:

Tanh fit: a=0.04485
Sigmoid fit: a=1.70099
Paper tanh error: 2.4329173471294176e-08
Alternative tanh error: 2.698034519269613e-08
Paper sigmoid error: 5.6479106346814546e-05
Alternative sigmoid error: 5.704246564663601e-05

2
为什么需要近似值?他们不能只使用erf函数吗?
SebiSebi

8

首先要注意的是

ΦX=1个2Ë[RFC-X2=1个21个+Ë[RFX2
按平价 Ë[RF。我们需要证明
Ë[RFX22πX+一种X3
对于 一种0.044715

对于较大的值 X,两个函数都受限制 [-1个1个]。对于小X,各自的泰勒级数读为

X=X-X33+ØX3
Ë[RFX=2πX-X33+ØX3
代入,我们得到
2πX+一种X3=2πX+一种-23πX3+ØX3
Ë[RFX2=2πX-X36+ØX3
的等式系数 X3, 我们发现
一种0.04553992412
接近论文的 0.044715

By using our site, you acknowledge that you have read and understand our Cookie Policy and Privacy Policy.
Licensed under cc by-sa 3.0 with attribution required.