Fast Approximations

Approximation the first

A quadratic approximation of exp ( x ) for nonnegative x near zero is

$$f_+(x) = 1 + x + x^2/2.$$

Noting that $\exp(x) = 1/\exp(-x)$ gives us

$$f(x) = \begin{cases} 1/f_+(-x) & \text{when } x < 0 \\ f_+(x) & \text{otherwise} \end{cases} = \begin{cases} 1/(1 - x + x^2/2) & \text{when } x < 0 \\ 1 + x + x^2/2 & \text{otherwise} \end{cases}.$$

Noting that $e^x = (e^{x/a})^a$ gives us

$$g(x) = f(x/4)^4 = \left(f(x/4)^2\right)^2.$$

This can be seen here.

A straightforward implementation parallelizes well:

/// Approximates `exp` lane by lane with a range-reduced quadratic.
///
/// Each input is divided by four, fed through the second-order Taylor
/// polynomial `1 + r + r²/2` (using `1 / (1 − r + r²/2)` for negative
/// inputs), and the result is raised to the fourth power by squaring twice.
/// `y[i]` receives the approximation of `exp(x[i])`.
pub fn vexpps(x : &[f32; 16], y : &mut [f32; 16])
{
  for (out, &input) in y.iter_mut().zip(x.iter()) {
    // Range reduction: exp(x) = exp(x / 4)⁴.
    let quarter = input * 0.25;
    let half_square = quarter * quarter * 0.5;

    // For negative inputs evaluate the polynomial at −r and invert, since
    // exp(r) = 1 / exp(−r); the sign test is on the original input.
    let base = if input < 0.0 {
      ((1.0 - quarter) + half_square).recip()
    } else {
      (1.0 + quarter) + half_square
    };

    // Undo the range reduction with two squarings.
    let squared = base * base;
    *out = squared * squared;
  }
}
; rustc --edition=2024 --codegen opt-level=3 --codegen target-cpu=znver5
.CONST0: .long 0x3e800000
.CONST1: .long 0x3f000000
.CONST2: .long 0x3f800000
vexpps:
  vmovups zmm0, zmmword ptr [rdi]
  vbroadcastss  zmm2, dword ptr [rip + .CONST2]
  vxorps  xmm5, xmm5, xmm5
  vmulps  zmm1, zmm0, dword ptr [rip + .CONST0]{1to16}
  vcmpltps  k1, zmm0, zmm5
  vmulps  zmm3, zmm1, zmm1
  vmulps  zmm3, zmm3, dword ptr [rip + .CONST1]{1to16}
  vaddps  zmm4, zmm1, zmm2
  vsubps  zmm0, zmm2, zmm1
  vaddps  zmm4, zmm4, zmm3
  vaddps  zmm0, zmm0, zmm3
  vdivps  zmm4  {k1}, zmm2, zmm0
  vmulps  zmm0, zmm4, zmm4
  vmulps  zmm0, zmm0, zmm0
  vmovups zmmword ptr [rsi], zmm0
  vzeroupper
  ret

For comparison, I believe the implementation of exp used in Rust can be seen here, which is not parallelized by the compiler.

The approximation has middling accuracy:

 input        approx       actual      abs error    rel error
−7.000000    +0.002977    +0.000912    +0.002065    +2.264217
−6.000000    +0.005791    +0.002479    +0.003312    +1.336334
−5.000000    +0.011844    +0.006738    +0.005106    +0.757864
−4.000000    +0.025600    +0.018316    +0.007284    +0.397713
−3.000000    +0.058742    +0.049787    +0.008955    +0.179859
−2.000000    +0.143412    +0.135335    +0.008077    +0.059682
−1.000000    +0.371077    +0.367879    +0.003198    +0.008693
 0.000000    +1.000000    +1.000000     0.000000     0.000000
+1.000000    +2.694856    +2.718282    −0.023426    −0.008618
+2.000000    +6.972900    +7.389056    −0.416156    −0.056321
+3.000000   +17.023682   +20.085537    −3.061855    −0.152441
+4.000000   +39.062500   +54.598148   −15.535648    −0.284545
+5.000000   +84.428101  +148.413162   −63.985062    −0.431128
+6.000000  +172.676025  +403.428802  −230.752777    −0.571979
+7.000000  +335.955963 +1096.633179  −760.677246    −0.693648

Approximation the second

Adding a third-order term and reducing the range further, seen here,

$$f_+(x) = 1 + x + x^2/2 + x^3/6, \qquad g(x) = f(x/8)^8 = \left(\left(f(x/8)^2\right)^2\right)^2$$

improves this at the cost of three additional multiplications:

 input        approx       actual      abs error    rel error
−7.000000    +0.001006    +0.000912    +0.000095    +0.103716
−6.000000    +0.002628    +0.002479    +0.000149    +0.060299
−5.000000    +0.006951    +0.006738    +0.000213    +0.031565
−4.000000    +0.018574    +0.018316    +0.000259    +0.014124
−3.000000    +0.050031    +0.049787    +0.000244    +0.004906
−2.000000    +0.135480    +0.135335    +0.000145    +0.001068
−1.000000    +0.367906    +0.367879    +0.000027    +0.000073
 0.000000    +1.000000    +1.000000     0.000000     0.000000
+1.000000    +2.718082    +2.718282    −0.000200    −0.000074
+2.000000    +7.381175    +7.389056    −0.007882    −0.001067
+3.000000   +19.987476   +20.085537    −0.098061    −0.004882
+4.000000   +53.837746   +54.598148    −0.760403    −0.013927
+5.000000  +143.871826  +148.413162    −4.541336    −0.030599
+6.000000  +380.485840  +403.428802   −22.942963    −0.056870
+7.000000  +993.582458 +1096.633179  −103.050720    −0.093970

Approximation the third

However, we can do better if we are willing to abuse the format of IEEE 754 single-precision floating-point numbers. The following is from A Fast, Compact Approximation of the Exponential Function by Nicol Schraudolph, which can be read here. My thanks to dear Cosmo, who found this and implemented it.

/// Approximates `exp` lane by lane by building the IEEE 754 bit pattern of
/// the result directly.
///
/// Each input is mapped through `x * SCALE + BIAS`, the result is converted
/// to an unsigned integer, and that integer is reinterpreted as the bits of
/// an `f32`, which is stored in the matching lane of `y`.
pub fn vexpps(x : &[f32; 16], y : &mut [f32; 16])
{
  // Number of explicitly stored fraction bits in an f32 (23).
  const FRACTION_BITS : u32 = f32::MANTISSA_DIGITS - 1;
  // 2²³ ÷ log 2: one unit of x / log 2 advances the exponent field by one.
  const SCALE : f32 = ((1 << FRACTION_BITS) as f32) / std::f32::consts::LN_2;
  // 2²³ × 127: the exponent bias, so that x = 0 yields the bits of 1.0.
  const BIAS : f32 = (127u32 << FRACTION_BITS) as f32;

  for (out, &input) in y.iter_mut().zip(x.iter()) {
    // SAFETY: `to_int_unchecked` requires a finite value that fits in `u32`
    // once its fraction is truncated. That holds only while
    // `input * SCALE + BIAS` lies in (−1, 2³²), roughly −88 < input < 266.
    // NOTE(review): nothing here enforces that range; NaN, infinite or
    // out-of-range inputs are undefined behaviour, so callers must ensure
    // inputs stay inside it.
    let bits = unsafe { input.mul_add(SCALE, BIAS).to_int_unchecked::<u32>() };
    *out = f32::from_bits(bits);
  }
}
; rustc --edition=2024 --codegen opt-level=3 --codegen target-cpu=znver5
.CONST0: .long 0x4b38aa3b
.CONST1: .long 0x4e7e0000
vexpps:
  vmovups      zmm1, zmmword ptr [rdi]
  vbroadcastss zmm0, dword ptr [rip + .CONST0]
  vfmadd213ps  zmm0, zmm1, dword ptr [rip + .CONST1]{1to16}
  vcvttps2udq  zmm0, zmm0
  vmovups      zmmword ptr [rsi], zmm0
  vzeroupper
  ret
 input        approx       actual      abs error    rel error
−7.000000    +0.000928    +0.000912    +0.000016    +0.017994
−6.000000    +0.002625    +0.002479    +0.000146    +0.058864
−5.000000    +0.006979    +0.006738    +0.000241    +0.035716
−4.000000    +0.019207    +0.018316    +0.000891    +0.048641
−3.000000    +0.052247    +0.049787    +0.002460    +0.049415
−2.000000    +0.139326    +0.135335    +0.003991    +0.029488
−1.000000    +0.389326    +0.367879    +0.021447    +0.058298
 0.000000    +1.000000    +1.000000     0.000000     0.000000
+1.000000    +2.885376    +2.718282    +0.167094    +0.061471
+2.000000    +7.541565    +7.389056    +0.152509    +0.020640
+3.000000   +21.249268   +20.085537    +1.163731    +0.057939
+4.000000   +56.665039   +54.598148    +2.066891    +0.037856
+5.000000  +155.324219  +148.413162    +6.911057    +0.046566
+6.000000  +423.980469  +403.428802   +20.551666    +0.050942
+7.000000 +1125.234375 +1096.633179   +28.601196    +0.026081
››› to-do ‹‹‹
write an intuitive explanation for why it is a piecewise linear function

It is effectively a piecewise linear function, as can be seen here, and therefore its derivative is a piecewise step function,

$$\frac{1}{\log 2} \times \exp\left(\lfloor x / \log 2 \rfloor \times \log 2\right),$$

as can be seen here.

Because this approximation is a piecewise linear function, the error is not monotonic on either side of zero, and there are regions where even the quadratic approximation has smaller error. The first two approximations are usually better than this approximation near zero, but this behaves well over the entire domain.

Approximation the fourth

We can improve the accuracy of this approximation by again making use of the fact that $\exp(x) = 1/\exp(-x)$. If $f$ is the piecewise linear function, then

$$g(x) = \left(f(x) + 1/f(-x)\right) / 2.$$

This can be optimized to the following. My thanks to Jonathan Hallström, who made this observation and implemented it. (The use of intrinsics is unfortunately necessary to coax the compiler to emit vrcp14ps.)

/// Approximates `exp` on sixteen lanes by averaging the piecewise linear
/// bit-trick approximation `f(x)` with the reciprocal of `f(−x)`.
///
/// The halving is folded into the exponent biases: biasing by 127 − 1 yields
/// `f(x) / 2` and biasing by 127 + 1 yields `2 · f(−x)`, whose approximate
/// reciprocal is `1 / (2 · f(−x))`, so a single add produces
/// `(f(x) + 1 / f(−x)) / 2`.
///
/// `x` and `y` need no particular alignment. The body uses AVX-512F
/// intrinsics, so it must only run on a CPU that supports them.
pub fn vexpps(x : &[f32; 16], y : &mut [f32; 16])
{
  // SAFETY: `x` and `y` are references to sixteen contiguous `f32`s, which is
  // exactly the 64 bytes the unaligned load and store touch. The intrinsics
  // additionally require AVX-512F on the executing CPU, which this function
  // does not check.
  unsafe {
    use std::arch::x86_64::{                 // I think you mean “std::arch::amd64”
      _mm512_set1_ps      as broadcast,
      _mm512_loadu_ps     as load,
      _mm512_storeu_ps    as store,
      _mm512_fmadd_ps     as mul_add,
      _mm512_fnmadd_ps    as neg_mul_add,
      _mm512_cvtps_epi32  as convert,
      _mm512_castsi512_ps as transmute,
      _mm512_rcp14_ps     as recip,
      _mm512_add_ps       as add
    };
    let a  = broadcast(0x00800000 as f32 / std::f32::consts::LN_2); // 2²³ ÷ log 2
    let b0 = broadcast(0x3f000000 as f32);                      // 2²³ × (127 − 1)
    let b1 = broadcast(0x40000000 as f32);                      // 2²³ × (127 + 1)

    // Unaligned load: a `&[f32; 16]` is only guaranteed 4-byte alignment, so
    // it must not be dereferenced as a 64-byte-aligned `__m512`.
    let xs = load(x.as_ptr());

    // f(x) / 2: the affine map, converted to an integer and reinterpreted as
    // the bits of a float.
    let aff0 = mul_add(xs, a, b0);
    let rnd0 = transmute(convert(aff0));

    // 1 / (2 · f(−x)): the same construction on −x, then an approximate
    // reciprocal.
    let aff1 = neg_mul_add(xs, a, b1);
    let rnd1 = transmute(convert(aff1));
    let inv1 = recip(rnd1);

    let avg = add(rnd0, inv1);
    store(y.as_mut_ptr(), avg);
  }
}
; rustc --edition=2024 --codegen opt-level=3 --codegen target-cpu=znver5
.CONST0: .long 0x4b38aa3b
.CONST1: .long 0x4e7c0000
.CONST2: .long 0x4e800000
vexpps:
  vmovaps      zmm0, zmmword ptr [rdi]
  vbroadcastss zmm1, dword ptr [rip + .CONST0]
  vbroadcastss zmm2, dword ptr [rip + .CONST1]
  vfmadd231ps  zmm2, zmm0, zmm1
  vfnmadd213ps zmm1, zmm0, dword ptr [rip + .CONST2]{1to16}
  vcvtps2dq    zmm2, zmm2
  vcvtps2dq    zmm0, zmm1
  vrcp14ps     zmm0, zmm0
  vaddps       zmm0, zmm0, zmm2
  vmovups      zmmword ptr [rsi], zmm0
  vzeroupper
  ret
 input        approx       actual      abs error    rel error
−7.000000    +0.000909    +0.000912    −0.000003    −0.003707
−6.000000    +0.002492    +0.002479    +0.000013    +0.005186
−5.000000    +0.006708    +0.006738    −0.000030    −0.004385
−4.000000    +0.018427    +0.018316    +0.000111    +0.006081
−3.000000    +0.049653    +0.049787    −0.000134    −0.002685
−2.000000    +0.135962    +0.135335    +0.000627    +0.004634
−1.000000    +0.367951    +0.367879    +0.000072    +0.000196
 0.000000    +1.000000    +1.000000     0.000000     0.000000
+1.000000    +2.726982    +2.718282    +0.008700    +0.003201
+2.000000    +7.359528    +7.389056    −0.029529    −0.003996
+3.000000   +20.194458   +20.085537    +0.108921    +0.005423
+4.000000   +54.365723   +54.598148    −0.232426    −0.004257
+5.000000  +149.310547  +148.413162    +0.897385    +0.006047
+6.000000  +402.488281  +403.428802    −0.940521    −0.002331
+7.000000 +1101.234375 +1096.633179    +4.601196    +0.004196

Adjusting the constants can reduce the error a skosh further, but zero then no longer maps to one.