blob: a8fa091a7cbf60ef6c408eca9888f51ad93ab7f4 [file] [log] [blame]
/*
* Vector math abstractions.
*
* Copyright (c) 2019-2023, Arm Limited.
* SPDX-License-Identifier: MIT OR Apache-2.0 WITH LLVM-exception
*/
#ifndef _V_MATH_H
#define _V_MATH_H
/* Vector math code is built by default; a build can opt out by defining
   WANT_VMATH to 0 before this header is included.  */
#if !defined(WANT_VMATH)
# define WANT_VMATH 1
#endif
#if WANT_VMATH
/* The goal of this header is to allow both vector (currently only Neon)
and scalar builds of the same algorithm. */
/* Symbol prefix and calling-convention hooks for the current build flavour:
     SCALAR               -> __s_<name>
     VPCS on AArch64      -> __vn_<name>, functions tagged aarch64_vector_pcs
     anything else        -> __v_<name>
   VPCS_ATTR and VPCS_ALIAS expand to nothing unless set above or predefined
   by the build.  */
#if SCALAR
# define V_NAME(x) __s_##x
#elif VPCS && __aarch64__
# define V_NAME(x) __vn_##x
# define VPCS_ATTR __attribute__ ((aarch64_vector_pcs))
#else
# define V_NAME(x) __v_##x
#endif

#ifndef VPCS_ATTR
# define VPCS_ATTR
#endif

#ifndef VPCS_ALIAS
# define VPCS_ALIAS
#endif
#include <stdint.h>
#include "math_config.h"
/* Scalar element types shared by the scalar and vector abstractions.  */
typedef float f32_t;
typedef uint32_t u32_t;
typedef int32_t s32_t;
typedef double f64_t;
typedef uint64_t u64_t;
typedef int64_t s64_t;

/* Bit-pattern reinterpretation, named as_<dst>_<src>: the argument's object
   representation is re-read as the destination type through a union, with
   no numeric conversion.  */

static inline u32_t
as_u32_f32 (f32_t x)
{
  union
  {
    u32_t bits;
    f32_t val;
  } pun = { .val = x };
  return pun.bits;
}

static inline f32_t
as_f32_u32 (u32_t x)
{
  union
  {
    f32_t val;
    u32_t bits;
  } pun = { .bits = x };
  return pun.val;
}

static inline s32_t
as_s32_u32 (u32_t x)
{
  union
  {
    s32_t sbits;
    u32_t ubits;
  } pun = { .ubits = x };
  return pun.sbits;
}

static inline u32_t
as_u32_s32 (s32_t x)
{
  union
  {
    u32_t ubits;
    s32_t sbits;
  } pun = { .sbits = x };
  return pun.ubits;
}

static inline u64_t
as_u64_f64 (f64_t x)
{
  union
  {
    u64_t bits;
    f64_t val;
  } pun = { .val = x };
  return pun.bits;
}

static inline f64_t
as_f64_u64 (u64_t x)
{
  union
  {
    f64_t val;
    u64_t bits;
  } pun = { .bits = x };
  return pun.val;
}

static inline s64_t
as_s64_u64 (u64_t x)
{
  union
  {
    s64_t sbits;
    u64_t ubits;
  } pun = { .ubits = x };
  return pun.sbits;
}

static inline u64_t
as_u64_s64 (s64_t x)
{
  union
  {
    u64_t ubits;
    s64_t sbits;
  } pun = { .sbits = x };
  return pun.ubits;
}
#if SCALAR
#define V_SUPPORTED 1
typedef f32_t v_f32_t;
typedef u32_t v_u32_t;
typedef s32_t v_s32_t;
typedef f64_t v_f64_t;
typedef u64_t v_u64_t;
typedef s64_t v_s64_t;
static inline int
v_lanes32 (void)
{
return 1;
}
static inline v_f32_t
v_f32 (f32_t x)
{
return x;
}
static inline v_u32_t
v_u32 (u32_t x)
{
return x;
}
static inline v_s32_t
v_s32 (s32_t x)
{
return x;
}
/* Lane accessors.  With a single lane the index I carries no information and
   is ignored.  */
static inline f32_t
v_get_f32 (v_f32_t x, int i)
{
  (void) i;
  return x;
}

static inline u32_t
v_get_u32 (v_u32_t x, int i)
{
  (void) i;
  return x;
}

static inline s32_t
v_get_s32 (v_s32_t x, int i)
{
  (void) i;
  return x;
}

static inline void
v_set_f32 (v_f32_t *x, int i, f32_t v)
{
  (void) i;
  x[0] = v;
}

static inline void
v_set_u32 (v_u32_t *x, int i, u32_t v)
{
  (void) i;
  x[0] = v;
}

static inline void
v_set_s32 (v_s32_t *x, int i, s32_t v)
{
  (void) i;
  x[0] = v;
}
/* Nonzero iff the (single-lane) v_cond result X is set.  */
static inline int
v_any_u32 (v_u32_t x)
{
  return !!x;
}

/* Widen a relational result to a mask: 0 stays 0, anything else becomes
   all-ones, matching the mask convention of the vector build.  */
static inline v_u32_t
v_cond_u32 (v_u32_t x)
{
  return -(v_u32_t) (x != 0);
}
static inline v_f32_t
v_abs_f32 (v_f32_t x)
{
  return __builtin_fabsf (x);
}

/* Bitwise select: bits of X where M is set, bits of Y elsewhere.  */
static inline v_u32_t
v_bsl_u32 (v_u32_t m, v_u32_t x, v_u32_t y)
{
  return y ^ ((x ^ y) & m);
}

/* Absolute-value comparisons.  Each yields 0 or 1, and 0 whenever either
   operand is NaN.  */

/* |x| > |y|.  */
static inline v_u32_t
v_cagt_f32 (v_f32_t x, v_f32_t y)
{
  return fabsf (y) < fabsf (x);
}

/* |x| >= |y|.  */
static inline v_u32_t
v_cage_f32 (v_f32_t x, v_f32_t y)
{
  return fabsf (y) <= fabsf (x);
}

/* |x| < |y|.  */
static inline v_u32_t
v_calt_f32 (v_f32_t x, v_f32_t y)
{
  return fabsf (y) > fabsf (x);
}

static inline v_f32_t
v_div_f32 (v_f32_t x, v_f32_t y)
{
  return x / y;
}

/* x * y + z with a single rounding.  */
static inline v_f32_t
v_fma_f32 (v_f32_t x, v_f32_t y, v_f32_t z)
{
  return __builtin_fmaf (x, y, z);
}

/* Round to nearest integer, halfway cases away from zero.  */
static inline v_f32_t
v_round_f32 (v_f32_t x)
{
  return __builtin_roundf (x);
}

static inline v_s32_t
v_round_s32 (v_f32_t x)
{
  return __builtin_lroundf (x); /* relies on -fno-math-errno. */
}

/* Return X when predicate P is nonzero, otherwise Y.  */
static inline v_f32_t
v_sel_f32 (v_u32_t p, v_f32_t x, v_f32_t y)
{
  if (p)
    return x;
  return y;
}

static inline v_u32_t
v_sel_u32 (v_u32_t p, v_u32_t x, v_u32_t y)
{
  if (p)
    return x;
  return y;
}

static inline v_f32_t
v_sqrt_f32 (v_f32_t x)
{
  return __builtin_sqrtf (x);
}
/* Numeric (value-changing) conversions, named v_to_<dst>_<src>.  */
static inline v_f32_t
v_to_f32_s32 (v_s32_t x)
{
  return (v_f32_t) x;
}

/* Float to int conversion truncates toward zero.  */
static inline v_s32_t
v_to_s32_f32 (v_f32_t x)
{
  return (v_s32_t) x;
}

static inline v_f32_t
v_to_f32_u32 (v_u32_t x)
{
  return (v_f32_t) x;
}
/* Bit-pattern reinterpretation, named v_as_<dst>_<src>.  In the scalar build
   the vector types are the element types, so these forward to the scalar
   as_* helpers.  */
static inline v_u32_t
v_as_u32_f32 (v_f32_t x)
{
  return as_u32_f32 (x);
}

static inline v_s32_t
v_as_s32_f32 (v_f32_t x)
{
  /* No direct f32 -> s32 helper exists; go through the unsigned bits.  */
  return as_s32_u32 (as_u32_f32 (x));
}

static inline v_f32_t
v_as_f32_u32 (v_u32_t x)
{
  return as_f32_u32 (x);
}

static inline v_s32_t
v_as_s32_u32 (v_u32_t x)
{
  return as_s32_u32 (x);
}

static inline v_u32_t
v_as_u32_s32 (v_s32_t x)
{
  return as_u32_s32 (x);
}
/* Table lookup: TAB[IDX] for the single lane.  */
static inline v_f32_t
v_lookup_f32 (const f32_t *tab, v_u32_t idx)
{
  const f32_t *slot = tab + idx;
  return *slot;
}

static inline v_u32_t
v_lookup_u32 (const u32_t *tab, v_u32_t idx)
{
  const u32_t *slot = tab + idx;
  return *slot;
}

/* Special-case fallback: always evaluates F on the input.  The default
   result Y and lane mask P are unused in the scalar build.
   NOTE(review): P is not consulted, so callers presumably only reach this
   when the lane is known to need F — confirm at call sites.  */
static inline v_f32_t
v_call_f32 (f32_t (*f) (f32_t), v_f32_t x, v_f32_t y, v_u32_t p)
{
  (void) y;
  (void) p;
  return f (x);
}

/* Two-argument variant of v_call_f32; Y and P are likewise unused.  */
static inline v_f32_t
v_call2_f32 (f32_t (*f) (f32_t, f32_t), v_f32_t x1, v_f32_t x2, v_f32_t y,
             v_u32_t p)
{
  (void) y;
  (void) p;
  return f (x1, x2);
}
static inline int
v_lanes64 (void)
{
return 1;
}
static inline v_f64_t
v_f64 (f64_t x)
{
return x;
}
static inline v_u64_t
v_u64 (u64_t x)
{
return x;
}
static inline v_s64_t
v_s64 (s64_t x)
{
return x;
}
static inline f64_t
v_get_f64 (v_f64_t x, int i)
{
return x;
}
static inline void
v_set_f64 (v_f64_t *x, int i, f64_t v)
{
*x = v;
}
/* Nonzero iff the (single-lane) v_cond result X is set.  */
static inline int
v_any_u64 (v_u64_t x)
{
  return x != 0;
}
/* Nonzero iff every lane of the v_cond result X is set; with one lane this
   is the same test as v_any_u64.  Compare against zero rather than returning
   X directly: converting a 64-bit value to int drops the upper 32 bits
   (1ull << 32 would read as false) and is implementation-defined for values
   above INT_MAX such as the all-ones mask.  */
static inline int
v_all_u64 (v_u64_t x)
{
  return x != 0;
}
/* Widen a relational result to a mask: 0 stays 0, anything else becomes
   all-ones, matching the mask convention of the vector build.  */
static inline v_u64_t
v_cond_u64 (v_u64_t x)
{
  return x ? -1 : 0;
}
static inline v_f64_t
v_abs_f64 (v_f64_t x)
{
  return __builtin_fabs (x);
}

/* Bitwise select: bits of X where M is set, bits of Y elsewhere.  */
static inline v_u64_t
v_bsl_u64 (v_u64_t m, v_u64_t x, v_u64_t y)
{
  return y ^ ((x ^ y) & m);
}

/* |x| > |y| as 0 or 1; 0 when either operand is NaN.  */
static inline v_u64_t
v_cagt_f64 (v_f64_t x, v_f64_t y)
{
  return fabs (y) < fabs (x);
}

static inline v_f64_t
v_div_f64 (v_f64_t x, v_f64_t y)
{
  return x / y;
}

/* x * y + z with a single rounding.  */
static inline v_f64_t
v_fma_f64 (v_f64_t x, v_f64_t y, v_f64_t z)
{
  return __builtin_fma (x, y, z);
}
/* Minimum of X and Y, matching the Neon build's vminq_f64 (FMIN): a NaN in
   either operand yields NaN, and min (-0.0, +0.0) is -0.0 regardless of
   argument order.  A plain `x < y ? x : y` returns Y whenever X is NaN and
   picks a zero by position, so the scalar and vector builds of the same
   routine would disagree on those inputs.  */
static inline v_f64_t
v_min_f64 (v_f64_t x, v_f64_t y)
{
  if (x != x || y != y)
    return x + y; /* NaN in, NaN out.  */
  if (x == y)
    /* Equal operands can differ only in the sign of zero; prefer -0.  */
    return __builtin_signbit (x) ? x : y;
  return x < y ? x : y;
}
/* Round to nearest integer, halfway cases away from zero.  */
static inline v_f64_t
v_round_f64 (v_f64_t x)
{
  return __builtin_round (x);
}

/* Return X when predicate P is nonzero, otherwise Y.  */
static inline v_f64_t
v_sel_f64 (v_u64_t p, v_f64_t x, v_f64_t y)
{
  if (p)
    return x;
  return y;
}

static inline v_f64_t
v_sqrt_f64 (v_f64_t x)
{
  return __builtin_sqrt (x);
}

static inline v_s64_t
v_round_s64 (v_f64_t x)
{
  return __builtin_lround (x); /* relies on -fno-math-errno. */
}

/* Truncate toward zero and convert to u64.
   NOTE(review): only meaningful for 0 <= x < 2^64; out-of-range values are
   undefined in C, whereas the Neon version (vcvtq_u64_f64) saturates.  */
static inline v_u64_t
v_trunc_u64 (v_f64_t x)
{
  v_f64_t whole = __builtin_trunc (x);
  return (v_u64_t) whole;
}
/* Numeric (value-changing) conversions, named v_to_<dst>_<src>.  */
static inline v_f64_t
v_to_f64_s64 (v_s64_t x)
{
  return (v_f64_t) x;
}

static inline v_f64_t
v_to_f64_u64 (v_u64_t x)
{
  return (v_f64_t) x;
}

/* Float to int conversion truncates toward zero.  */
static inline v_s64_t
v_to_s64_f64 (v_f64_t x)
{
  return (v_s64_t) x;
}
/* Bit-pattern reinterpretation, named v_as_<dst>_<src>.  In the scalar build
   the vector types are the element types, so these forward to the scalar
   as_* helpers.  */
static inline v_u64_t
v_as_u64_f64 (v_f64_t x)
{
  return as_u64_f64 (x);
}

static inline v_f64_t
v_as_f64_u64 (v_u64_t x)
{
  return as_f64_u64 (x);
}

static inline v_s64_t
v_as_s64_u64 (v_u64_t x)
{
  return as_s64_u64 (x);
}

static inline v_u64_t
v_as_u64_s64 (v_s64_t x)
{
  return as_u64_s64 (x);
}
/* Table lookup: TAB[IDX] for the single lane.  */
static inline v_f64_t
v_lookup_f64 (const f64_t *tab, v_u64_t idx)
{
  const f64_t *slot = tab + idx;
  return *slot;
}

static inline v_u64_t
v_lookup_u64 (const u64_t *tab, v_u64_t idx)
{
  const u64_t *slot = tab + idx;
  return *slot;
}

/* Special-case fallback: always evaluates F on the input.  The default
   result Y and lane mask P are unused in the scalar build.  */
static inline v_f64_t
v_call_f64 (f64_t (*f) (f64_t), v_f64_t x, v_f64_t y, v_u64_t p)
{
  (void) y;
  (void) p;
  return f (x);
}

/* Two-argument variant of v_call_f64; Y and P are likewise unused.  */
static inline v_f64_t
v_call2_f64 (f64_t (*f) (f64_t, f64_t), v_f64_t x1, v_f64_t x2, v_f64_t y,
             v_u64_t p)
{
  (void) y;
  (void) p;
  return f (x1, x2);
}
#elif __aarch64__
#define V_SUPPORTED 1
#include <arm_neon.h>
/* AdvSIMD (Neon) build: 128-bit vectors of four 32-bit or two 64-bit lanes.  */
typedef float32x4_t v_f32_t;
typedef uint32x4_t v_u32_t;
typedef int32x4_t v_s32_t;
typedef float64x2_t v_f64_t;
typedef uint64x2_t v_u64_t;
typedef int64x2_t v_s64_t;

/* Number of 32-bit lanes per vector.  */
static inline int
v_lanes32 (void)
{
  return (int) (sizeof (v_f32_t) / sizeof (f32_t));
}

/* Broadcast a scalar into every lane.  */
static inline v_f32_t
v_f32 (f32_t x)
{
  return vdupq_n_f32 (x);
}

static inline v_u32_t
v_u32 (u32_t x)
{
  return vdupq_n_u32 (x);
}

static inline v_s32_t
v_s32 (s32_t x)
{
  return vdupq_n_s32 (x);
}
/* Lane accessors with a runtime lane index I.  GNU vector subscripting is
   used because the vgetq_lane/vsetq_lane intrinsics require a constant lane
   number.  */
static inline f32_t
v_get_f32 (v_f32_t x, int i)
{
  f32_t lane = x[i];
  return lane;
}

static inline u32_t
v_get_u32 (v_u32_t x, int i)
{
  u32_t lane = x[i];
  return lane;
}

static inline s32_t
v_get_s32 (v_s32_t x, int i)
{
  s32_t lane = x[i];
  return lane;
}

static inline void
v_set_f32 (v_f32_t *x, int i, f32_t v)
{
  v_f32_t tmp = *x;
  tmp[i] = v;
  *x = tmp;
}

static inline void
v_set_u32 (v_u32_t *x, int i, u32_t v)
{
  v_u32_t tmp = *x;
  tmp[i] = v;
  *x = tmp;
}

static inline void
v_set_s32 (v_s32_t *x, int i, s32_t v)
{
  v_s32_t tmp = *x;
  tmp[i] = v;
  *x = tmp;
}
/* Nonzero iff any lane of the v_cond mask X is set.  Lanes must be either 0
   or all-ones: the mask is viewed as two 64-bit halves and summed with one
   pairwise add, which cannot wrap to zero under that assumption.  */
static inline int
v_any_u32 (v_u32_t x)
{
  uint64_t folded = vpaddd_u64 (vreinterpretq_u64_u32 (x));
  return folded != 0;
}

/* Neon comparisons already produce all-ones/all-zeros masks; pass through.  */
static inline v_u32_t
v_cond_u32 (v_u32_t x)
{
  return x;
}
static inline v_f32_t
v_abs_f32 (v_f32_t x)
{
  return vabsq_f32 (x);
}

/* Bitwise select: bits of X where M is set, bits of Y elsewhere.  */
static inline v_u32_t
v_bsl_u32 (v_u32_t m, v_u32_t x, v_u32_t y)
{
  return vbslq_u32 (m, x, y);
}

/* Absolute-value comparisons returning per-lane masks.  Each is expressed
   through the mirrored intrinsic with swapped operands.  */

/* |x| > |y|, i.e. |y| < |x|.  */
static inline v_u32_t
v_cagt_f32 (v_f32_t x, v_f32_t y)
{
  return vcaltq_f32 (y, x);
}

/* |x| >= |y|, i.e. |y| <= |x|.  */
static inline v_u32_t
v_cage_f32 (v_f32_t x, v_f32_t y)
{
  return vcaleq_f32 (y, x);
}

/* |x| < |y|, i.e. |y| > |x|.  */
static inline v_u32_t
v_calt_f32 (v_f32_t x, v_f32_t y)
{
  return vcagtq_f32 (y, x);
}

static inline v_f32_t
v_div_f32 (v_f32_t x, v_f32_t y)
{
  return vdivq_f32 (x, y);
}

/* x * y + z with a single rounding (accumulator goes first).  */
static inline v_f32_t
v_fma_f32 (v_f32_t x, v_f32_t y, v_f32_t z)
{
  return vfmaq_f32 (z, x, y);
}

/* Round to nearest integer, halfway cases away from zero.  */
static inline v_f32_t
v_round_f32 (v_f32_t x)
{
  return vrndaq_f32 (x);
}

static inline v_s32_t
v_round_s32 (v_f32_t x)
{
  return vcvtaq_s32_f32 (x);
}

/* Per-lane select: X where mask P is set, otherwise Y.  */
static inline v_f32_t
v_sel_f32 (v_u32_t p, v_f32_t x, v_f32_t y)
{
  return vbslq_f32 (p, x, y);
}

static inline v_u32_t
v_sel_u32 (v_u32_t p, v_u32_t x, v_u32_t y)
{
  return vbslq_u32 (p, x, y);
}

static inline v_f32_t
v_sqrt_f32 (v_f32_t x)
{
  return vsqrtq_f32 (x);
}
/* Numeric (value-changing) conversions, named v_to_<dst>_<src>.  All three
   use the single-instruction Neon conversion intrinsics rather than
   rebuilding the vector lane by lane.  */

/* Lane-wise int32 -> float (SCVTF).  */
static inline v_f32_t
v_to_f32_s32 (v_s32_t x)
{
  return vcvtq_f32_s32 (x);
}

/* Lane-wise float -> int32, truncating toward zero (FCVTZS).  */
static inline v_s32_t
v_to_s32_f32 (v_f32_t x)
{
  return vcvtq_s32_f32 (x);
}

/* Lane-wise uint32 -> float (UCVTF).  */
static inline v_f32_t
v_to_f32_u32 (v_u32_t x)
{
  return vcvtq_f32_u32 (x);
}
/* Bit-pattern reinterpretation, named v_as_<dst>_<src>; no lane values are
   converted, only the vector's type changes.  */
static inline v_u32_t
v_as_u32_f32 (v_f32_t x)
{
  return vreinterpretq_u32_f32 (x);
}

static inline v_s32_t
v_as_s32_f32 (v_f32_t x)
{
  return vreinterpretq_s32_f32 (x);
}

static inline v_f32_t
v_as_f32_u32 (v_u32_t x)
{
  return vreinterpretq_f32_u32 (x);
}

static inline v_s32_t
v_as_s32_u32 (v_u32_t x)
{
  return vreinterpretq_s32_u32 (x);
}

static inline v_u32_t
v_as_u32_s32 (v_s32_t x)
{
  return vreinterpretq_u32_s32 (x);
}
/* Gather: lane K of the result is TAB[IDX[K]].  */
static inline v_f32_t
v_lookup_f32 (const f32_t *tab, v_u32_t idx)
{
  v_f32_t out = {0};
  for (int lane = 0; lane < 4; lane++)
    {
      out[lane] = tab[idx[lane]];
    }
  return out;
}

static inline v_u32_t
v_lookup_u32 (const u32_t *tab, v_u32_t idx)
{
  v_u32_t out = {0};
  for (int lane = 0; lane < 4; lane++)
    {
      out[lane] = tab[idx[lane]];
    }
  return out;
}

/* Special-case fallback: lanes whose mask P is set get F (X[lane]); the
   other lanes keep the default result Y.  F is called only for set lanes.  */
static inline v_f32_t
v_call_f32 (f32_t (*f) (f32_t), v_f32_t x, v_f32_t y, v_u32_t p)
{
  v_f32_t out = y;
  for (int lane = 0; lane < 4; lane++)
    {
      if (p[lane])
        out[lane] = f (x[lane]);
    }
  return out;
}

/* Two-argument variant of v_call_f32.  */
static inline v_f32_t
v_call2_f32 (f32_t (*f) (f32_t, f32_t), v_f32_t x1, v_f32_t x2, v_f32_t y,
             v_u32_t p)
{
  v_f32_t out = y;
  for (int lane = 0; lane < 4; lane++)
    {
      if (p[lane])
        out[lane] = f (x1[lane], x2[lane]);
    }
  return out;
}
static inline int
v_lanes64 (void)
{
return 2;
}
static inline v_f64_t
v_f64 (f64_t x)
{
return (v_f64_t){x, x};
}
static inline v_u64_t
v_u64 (u64_t x)
{
return (v_u64_t){x, x};
}
static inline v_s64_t
v_s64 (s64_t x)
{
return (v_s64_t){x, x};
}
static inline f64_t
v_get_f64 (v_f64_t x, int i)
{
return x[i];
}
static inline void
v_set_f64 (v_f64_t *x, int i, f64_t v)
{
(*x)[i] = v;
}
/* Nonzero iff any lane of the v_cond mask X is set.  */
static inline int
v_any_u64 (v_u64_t x)
{
  /* Assumes each lane is either 0 or all-ones (UINT64_MAX); the pairwise
     sum of two such lanes is zero only when both lanes are zero.  */
  return vpaddd_u64 (x) != 0;
}
/* Nonzero iff every lane of the v_cond mask X is set (all-ones).  */
static inline int
v_all_u64 (v_u64_t x)
{
  /* Assumes each lane is either 0 or all-ones: read as signed, a set lane is
     -1, so both lanes are set exactly when their sum is -2.  */
  return vpaddd_s64 (vreinterpretq_s64_u64 (x)) == -2;
}
/* Neon comparisons already produce all-ones/all-zeros masks; pass through.  */
static inline v_u64_t
v_cond_u64 (v_u64_t x)
{
  return x;
}
static inline v_f64_t
v_abs_f64 (v_f64_t x)
{
  return vabsq_f64 (x);
}

/* Bitwise select: bits of X where M is set, bits of Y elsewhere.  */
static inline v_u64_t
v_bsl_u64 (v_u64_t m, v_u64_t x, v_u64_t y)
{
  return vbslq_u64 (m, x, y);
}

/* |x| > |y| per lane, written as |y| < |x|.  */
static inline v_u64_t
v_cagt_f64 (v_f64_t x, v_f64_t y)
{
  return vcaltq_f64 (y, x);
}

static inline v_f64_t
v_div_f64 (v_f64_t x, v_f64_t y)
{
  return vdivq_f64 (x, y);
}

/* x * y + z with a single rounding (accumulator goes first).  */
static inline v_f64_t
v_fma_f64 (v_f64_t x, v_f64_t y, v_f64_t z)
{
  return vfmaq_f64 (z, x, y);
}

static inline v_f64_t
v_min_f64 (v_f64_t x, v_f64_t y)
{
  return vminq_f64 (x, y);
}

/* Round to nearest integer, halfway cases away from zero.  */
static inline v_f64_t
v_round_f64 (v_f64_t x)
{
  return vrndaq_f64 (x);
}

/* Per-lane select: X where mask P is set, otherwise Y.  */
static inline v_f64_t
v_sel_f64 (v_u64_t p, v_f64_t x, v_f64_t y)
{
  return vbslq_f64 (p, x, y);
}

static inline v_f64_t
v_sqrt_f64 (v_f64_t x)
{
  return vsqrtq_f64 (x);
}

static inline v_s64_t
v_round_s64 (v_f64_t x)
{
  return vcvtaq_s64_f64 (x);
}

static inline v_u64_t
v_trunc_u64 (v_f64_t x)
{
  return vcvtq_u64_f64 (x);
}
/* Numeric (value-changing) conversions, named v_to_<dst>_<src>.  All three
   use the single-instruction Neon conversion intrinsics rather than
   rebuilding the vector lane by lane.  */

/* Lane-wise int64 -> double (SCVTF).  */
static inline v_f64_t
v_to_f64_s64 (v_s64_t x)
{
  return vcvtq_f64_s64 (x);
}

/* Lane-wise uint64 -> double (UCVTF).  */
static inline v_f64_t
v_to_f64_u64 (v_u64_t x)
{
  return vcvtq_f64_u64 (x);
}

/* Lane-wise double -> int64, truncating toward zero (FCVTZS).  */
static inline v_s64_t
v_to_s64_f64 (v_f64_t x)
{
  return vcvtq_s64_f64 (x);
}
/* Bit-pattern reinterpretation, named v_as_<dst>_<src>; no lane values are
   converted, only the vector's type changes.  */
static inline v_u64_t
v_as_u64_f64 (v_f64_t x)
{
  return vreinterpretq_u64_f64 (x);
}

static inline v_f64_t
v_as_f64_u64 (v_u64_t x)
{
  return vreinterpretq_f64_u64 (x);
}

static inline v_s64_t
v_as_s64_u64 (v_u64_t x)
{
  return vreinterpretq_s64_u64 (x);
}

static inline v_u64_t
v_as_u64_s64 (v_s64_t x)
{
  return vreinterpretq_u64_s64 (x);
}
/* Gather: lane K of the result is TAB[IDX[K]].  */
static inline v_f64_t
v_lookup_f64 (const f64_t *tab, v_u64_t idx)
{
  v_f64_t out = {0};
  for (int lane = 0; lane < 2; lane++)
    {
      out[lane] = tab[idx[lane]];
    }
  return out;
}

static inline v_u64_t
v_lookup_u64 (const u64_t *tab, v_u64_t idx)
{
  v_u64_t out = {0};
  for (int lane = 0; lane < 2; lane++)
    {
      out[lane] = tab[idx[lane]];
    }
  return out;
}

/* Special-case fallback: lanes whose mask P is set get F (X[lane]); the
   other lanes keep the default result Y.  F is called only for set lanes.  */
static inline v_f64_t
v_call_f64 (f64_t (*f) (f64_t), v_f64_t x, v_f64_t y, v_u64_t p)
{
  v_f64_t out = y;
  for (int lane = 0; lane < 2; lane++)
    {
      if (p[lane])
        out[lane] = f (x[lane]);
    }
  return out;
}

/* Two-argument variant of v_call_f64.  */
static inline v_f64_t
v_call2_f64 (f64_t (*f) (f64_t, f64_t), v_f64_t x1, v_f64_t x2, v_f64_t y,
             v_u64_t p)
{
  v_f64_t out = y;
  for (int lane = 0; lane < 2; lane++)
    {
      if (p[lane])
        out[lane] = f (x1[lane], x2[lane]);
    }
  return out;
}
#endif
#endif
#endif