Skip to content

Commit

Permalink
added pow function
Browse files Browse the repository at this point in the history
  • Loading branch information
Geolm committed Feb 1, 2024
1 parent 0b65dcd commit a30bdd2
Show file tree
Hide file tree
Showing 3 changed files with 393 additions and 34 deletions.
220 changes: 219 additions & 1 deletion math_intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ extern "C" {
// max error : 4.768371582e-07
float32x4_t vcbrtq_f32(float32x4_t a);

// max error : 1.484901873e-07
float32x4_t vpowq_f32(float32x4_t x, float32x4_t y);

#define __MATH__INTRINSICS__NEON__

#else
Expand Down Expand Up @@ -93,6 +96,9 @@ extern "C" {
// max error : 4.768371582e-07
__m256 mm256_cbrt_ps(__m256 a);

// max error : 1.484901873e-07
__m256 mm256_pow_ps(__m256 x, __m256 y);

#define __MATH__INTRINSICS__AVX__

#endif
Expand Down Expand Up @@ -167,11 +173,22 @@ extern "C" {
static inline simd_vectori simd_andnot_i(simd_vectori a, simd_vectori b) {return vbicq_s32(a, b);}
static inline simd_vectori simd_cmp_eq_i(simd_vectori a, simd_vectori b) {return vceqq_s32(a, b);}
static inline simd_vectori simd_cmp_gt_i(simd_vectori a, simd_vectori b) {return vcgtq_s32(a, b);}
static inline simd_vectori simd_min_i(simd_vectori a, simd_vectori b) {return vminq_s32(a, b);}
static inline simd_vectori simd_max_i(simd_vectori a, simd_vectori b) {return vmaxq_s32(a, b);}
static inline simd_vector simd_gather(const float* array, simd_vectori indices)
{
float tmp[4] = {array[indices[0]], array[indices[1]], array[indices[2]], array[indices[3]]};
return vld1q_f32(tmp);
}

#define simd_asin vasinq_f32
#define simd_atan vatanq_f32
#define simd_sincos vsincosq_f32
#define simd_sin vsinq_f32
#define simd_log vlogq_f32
#define simd_exp vexpq_f32
#define simd_log2 vlog2q_f32
#define simd_exp2 vexp2q_f32

#else
typedef __m256 simd_vector;
Expand Down Expand Up @@ -231,17 +248,25 @@ extern "C" {
static inline simd_vectori simd_splat_i(int i) {return _mm256_set1_epi32(i);}
static inline simd_vectori simd_splat_zero_i(void) {return _mm256_setzero_si256();}
static inline simd_vectori simd_shift_left_i(simd_vectori a, int i) {return _mm256_slli_epi32(a, i);}
static inline simd_vectori simd_shift_right_i(simd_vectori a, int i) {return _mm256_srli_epi32(a, i);}
static inline simd_vectori simd_shift_right_i(simd_vectori a, int i) {return _mm256_srai_epi32(a, i);}
static inline simd_vectori simd_and_i(simd_vectori a, simd_vectori b) {return _mm256_and_si256(a, b);}
static inline simd_vectori simd_or_i(simd_vectori a, simd_vectori b) {return _mm256_or_si256(a, b);}
static inline simd_vectori simd_andnot_i(simd_vectori a, simd_vectori b) {return _mm256_andnot_si256(b, a);}
static inline simd_vectori simd_cmp_eq_i(simd_vectori a, simd_vectori b) {return _mm256_cmpeq_epi32(a, b);}
static inline simd_vectori simd_cmp_gt_i(simd_vectori a, simd_vectori b) {return _mm256_cmpgt_epi32(a, b);}
static inline simd_vectori simd_min_i(simd_vectori a, simd_vectori b) {return _mm256_min_epi32(a, b);}
static inline simd_vectori simd_max_i(simd_vectori a, simd_vectori b) {return _mm256_max_epi32(a, b);}
static inline simd_vector simd_gather(const float* array, simd_vectori indices) {return _mm256_i32gather_ps(array, indices, 4);}


#define simd_asin mm256_asin_ps
#define simd_atan mm256_atan_ps
#define simd_sincos mm256_sincos_ps
#define simd_sin mm256_sin_ps
#define simd_exp mm256_exp_ps
#define simd_log mm256_log_ps
#define simd_exp2 mm256_exp2_ps
#define simd_log2 mm256_log2_ps

#endif

Expand Down Expand Up @@ -309,6 +334,10 @@ static inline simd_vector simd_sign(simd_vector a)
return simd_select(result, simd_splat( 1.f), simd_cmp_gt(a, simd_splat_zero()));
}

static inline simd_vectori simd_select_i(simd_vectori a, simd_vectori b, simd_vectori mask) { return simd_or_i(simd_andnot_i(a, mask), simd_and_i(b, mask));}
static inline simd_vectori simd_neg_i(simd_vectori a){return simd_sub_i(simd_splat_zero_i(), a);}


//----------------------------------------------------------------------------------------------------------------------
// based on http://gruntthepeon.free.fr/ssemath/
#ifdef __MATH__INTRINSICS__NEON__
Expand Down Expand Up @@ -810,6 +839,195 @@ __m256 mm256_cos_ps(__m256 x)
return x;
}

static inline simd_vector reduc(simd_vector x) {return simd_mul(simd_splat(0.0625f), simd_floor( simd_mul(simd_splat(16.f),x)));}

//----------------------------------------------------------------------------------------------------------------------
// based on https://github.com/jeremybarnes/cephes/blob/master/single/powf.c
#ifdef __MATH__INTRINSICS__NEON__
float32x4_t vpowq_f32(float32x4_t x, float32x4_t y)
#else
__m256 mm256_pow_ps(__m256 x, __m256 y)
#endif
{
simd_vector x_equals_zero = simd_cmp_eq(x, simd_splat_zero());
simd_vector y_equals_zero = simd_cmp_eq(y, simd_splat_zero());
simd_vector non_integer_power = simd_cmp_neq(y, simd_floor(y));
simd_vector return_zero = simd_andnot(x_equals_zero, y_equals_zero);
simd_vector return_one = simd_and(x_equals_zero, y_equals_zero);
simd_vector return_nan = simd_and(simd_cmp_lt(x, simd_splat_zero()), non_integer_power);

#ifdef __MATH_INTRINSINCS_FAST__
simd_vector z = simd_exp2(simd_mul(y, simd_log2(x)));
#else
// 2^(-i/16) The decimal values are rounded to 24-bit precision
static float A[] =
{
1.00000000000000000000E0f,
9.57603275775909423828125E-1f,
9.17004048824310302734375E-1f,
8.78126084804534912109375E-1f,
8.40896427631378173828125E-1f,
8.05245161056518554687500E-1f,
7.71105408668518066406250E-1f,
7.38413095474243164062500E-1f,
7.07106769084930419921875E-1f,
6.77127778530120849609375E-1f,
6.48419797420501708984375E-1f,
6.20928883552551269531250E-1f,
5.94603538513183593750000E-1f,
5.69394290447235107421875E-1f,
5.45253872871398925781250E-1f,
5.22136867046356201171875E-1f,
5.00000000000000000000E-1f
};

// continuation, for even i only 2^(i/16) = A[i] + B[i/2]
static float B[] =
{
0.00000000000000000000E0f,
-5.61963907099083340520586E-9f,
-1.23776636307969995237668E-8f,
4.03545234539989593104537E-9f,
1.21016171044789693621048E-8f,
-2.00949968760174979411038E-8f,
1.89881769396087499852802E-8f,
-6.53877009617774467211965E-9f,
0.00000000000000000000E0f
};

// 1 / A[i] The decimal values are full precision
static float Ainv[] =
{
1.00000000000000000000000E0f,
1.04427378242741384032197E0f,
1.09050773266525765920701E0f,
1.13878863475669165370383E0f,
1.18920711500272106671750E0f,
1.24185781207348404859368E0f,
1.29683955465100966593375E0f,
1.35425554693689272829801E0f,
1.41421356237309504880169E0f,
1.47682614593949931138691E0f,
1.54221082540794082361229E0f,
1.61049033194925430817952E0f,
1.68179283050742908606225E0f,
1.75625216037329948311216E0f,
1.83400808640934246348708E0f,
1.91520656139714729387261E0f,
2.00000000000000000000000E0f
};

simd_vector neg_x = simd_andnot(simd_cmp_lt(x, simd_splat_zero()), non_integer_power);

x = simd_select(x, simd_neg(x), neg_x);

// separate significand from exponent
simd_vector e;
x = simd_frexp(x, &e);

// find significand in antilog table A[]
simd_vectori i = simd_splat_i(1);
i = simd_select_i(i, simd_splat_i(9), simd_cast_from_float(simd_cmp_le(x, simd_splat(A[9]))));
simd_vectori i_plus_4 = simd_add_i(i, simd_splat_i(4));
i = simd_select_i(i, i_plus_4, simd_cast_from_float(simd_cmp_le(x, simd_gather(A, i_plus_4))));
simd_vectori i_plus_2 = simd_add_i(i, simd_splat_i(2));
i = simd_select_i(i, i_plus_2, simd_cast_from_float(simd_cmp_le(x, simd_gather(A, i_plus_2))));
i = simd_select_i(i, simd_splat_i(-1), simd_cast_from_float(simd_cmp_ge(x, simd_splat(A[1]))));
i = simd_add_i(i, simd_splat_i(1));

// Find (x - A[i])/A[i]
// in order to compute log(x/A[i]):
// log(x) = log( a x/a ) = log(a) + log(x/a)
// log(x/a) = log(1+v), v = x/a - 1 = (x-a)/a
x = simd_sub(x, simd_gather(A, i));
x = simd_sub(x, simd_gather(B, simd_shift_right_i(i, 1)));
x = simd_mul(x, simd_gather(Ainv, i));

// rational approximation for log(1+v):
// log(1+v) = v - 0.5 v^2 + v^3 P(v)
// Theoretical relative error of the approximation is 3.5e-11
// on the interval 2^(1/16) - 1 > v > 2^(-1/16) - 1
simd_vector z = simd_mul(x, x);
simd_vector w = simd_polynomial4(x, (float[]){-0.1663883081054895f, 0.2003770364206271f, -0.2500006373383951f, 0.3333331095506474f});
w = simd_mul(w, x);
w = simd_fmad(w, z, simd_mul(simd_splat(-.5f), z));

// Convert to base 2 logarithm: multiply by log2(e)
simd_vector LOG2EA = simd_splat(0.44269504088896340736F);
w = simd_fmad(w, LOG2EA, w);

// Note x was not yet added in to above rational approximation,
// so do it now, while multiplying by log2(e).
z = simd_fmad(x, LOG2EA, w);
z = simd_add(z, x);

// Compute exponent term of the base 2 logarithm.
w = simd_neg(simd_convert_from_int(i));
w = simd_fmad(w, simd_splat(0.0625f), e);

// Multiply base 2 log by y, in extended precision.
// separate y into large part ya and small part yb less than 1/16
simd_vector ya = reduc(y);
simd_vector yb = simd_sub(y, ya);

simd_vector W = simd_fmad(z, y, simd_mul(w, yb));
simd_vector Wa = reduc(W);
simd_vector Wb = simd_sub(W, Wa);

W = simd_fmad(w, ya, Wa);
Wa = reduc(W);
simd_vector u = simd_sub(W, Wa);

W = simd_add(Wb, u);
Wb = reduc(W);
w = simd_mul(simd_splat(16.f), simd_add(Wa, Wb));

return_zero = simd_or(return_zero, simd_cmp_lt(w, simd_splat(-2400.0f)));

e = w;
Wb = simd_sub(W, Wb);

simd_vector gt_zero = simd_cmp_gt(Wb, simd_splat_zero());
e = simd_select(e, simd_add(e, simd_splat(1.f)), gt_zero);
Wb = simd_select(Wb, simd_sub(Wb, simd_splat(0.0625f)), gt_zero);

// Now the product y * log2(x) = Wb + e/16.0.
// Compute base 2 exponential of Wb,
// where -0.0625 <= Wb <= 0.
// Theoretical relative error of the approximation is 2.8e-12.
// z = 2**Wb - 1
z = simd_polynomial4(Wb, (float[]) {9.416993633606397E-003f, 5.549356188719141E-002f, 2.402262883964191E-001f, 6.931471791490764E-001f});
z = simd_mul(z, Wb);

simd_vector neg_e = simd_cmp_lt(e, simd_splat_zero());
simd_vectori int_e = simd_convert_from_float(e);

simd_vectori i0 = simd_neg_i(simd_shift_right_i(simd_neg_i(int_e), 4));
simd_vectori i1 = simd_add_i(simd_shift_right_i(int_e, 4), simd_splat_i(1));
i = simd_select_i(i1, i0, simd_cast_from_float(neg_e));

int_e = simd_sub_i(simd_shift_left_i(i, 4), int_e);

// clamp int_e to avoid reading data out of the array
int_e = simd_min_i(int_e, simd_splat_i(16));
int_e = simd_max_i(int_e, simd_splat_i(0));

w = simd_gather(A, int_e);
z = simd_fmad(w, z, w); // 2^-e * ( 1 + (2^Hb-1) )
z = simd_ldexp(z, simd_convert_from_int(i));

// For negative x, find out if the integer exponent is odd or even.
w = simd_mul(simd_splat(2.f), simd_floor(simd_mul(simd_splat(.5f), w)));
z = simd_select(z, simd_neg(z), simd_and(neg_x, simd_cmp_neq(w, y)));
#endif

z = simd_andnot(z, return_zero);
z = simd_select(z, simd_splat(1.f), return_one);
z = simd_or(z, return_nan);

return z;
}

#endif // __MATH__INTRINSICS__IMPLEMENTATION__


Loading

0 comments on commit a30bdd2

Please sign in to comment.