Skip to content

Commit

Permalink
Add hardcoded flint_mpn_aors_n for ARM and x86
Browse files Browse the repository at this point in the history
These are generated from `dev/gen_ARCH_aors.jl`.  Also add tests for it.
  • Loading branch information
Albin Ahlbäck committed Nov 28, 2024
1 parent 9b3d2b6 commit d09c98e
Show file tree
Hide file tree
Showing 8 changed files with 1,450 additions and 0 deletions.
94 changes: 94 additions & 0 deletions dev/gen_arm_aors.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
#
# Copyright (C) 2024 Albin Ahlbäck
#
# This file is part of FLINT.
#
# FLINT is free software: you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License (LGPL) as published
# by the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version. See <https://www.gnu.org/licenses/>.
#

# Generating routines for r <- a OP b, where OP is either + or -.
#
# This generator was written with Apple silicon processors in mind.
# Processors that decode fewer than 6 operations per cycle, or that have few
# load and store units, may see worse performance.

# Placeholder names for the result pointer and the two source pointers.
r = "rp"
a = "ap"
b = "bp"

# Memory operand for the word at index ix (scaled by 8) relative to a pointer.
rp(ix::Int) = string("[", r, ",#", ix, "*8]")
ap(ix::Int) = string("[", a, ",#", ix, "*8]")
bp(ix::Int) = string("[", b, ",#", ix, "*8]")

sx = "sx" # Return value for carry or borrow
CC = "CC" # Placeholder condition code handed to cset

sp = map(ix -> "s$ix", 0:14) # Scratch registers s0, ..., s14

# Returns, as a String, assembly for flint_mpn_aors(n) that should be
# preprocessed by M4.
#
# The routine processes the limbs two at a time with ldp/stp: the first pair
# uses OP followed by OPC, every later pair uses OPC only, and an odd trailing
# limb is handled with ldr/str.  Finally cset writes the condition CC into sx.
#
# n must be at least 2, because the leading ldp pair is emitted
# unconditionally and always touches two limbs.  An ArgumentError is thrown
# otherwise (previously incorrect assembly was generated silently).
function aors(n::Int)
    n >= 2 || throw(ArgumentError("aors requires n >= 2, got n = $n"))

    # Collect the output in a buffer instead of repeatedly concatenating strings.
    io = IOBuffer()

    # Append one instruction line of the form "\t<mnemonic>\t<op>, <op>, ...\n".
    emit(mnemonic::String, operands::String...) =
        print(io, '\t', mnemonic, '\t', join(operands, ", "), '\n')

    ldr(s0::String, s1::String) = emit("ldr", s0, s1)
    ldp(s0::String, s1::String, s2::String) = emit("ldp", s0, s1, s2)
    str(s0::String, s1::String) = emit("str", s0, s1)
    stp(s0::String, s1::String, s2::String) = emit("stp", s0, s1, s2)
    OP(s0::String, s1::String, s2::String) = emit("OP", s0, s1, s2)
    OPC(s0::String, s1::String, s2::String) = emit("OPC", s0, s1, s2)
    cset(s0::String, s1::String) = emit("cset", s0, s1)

    # Private copy of the scratch register names, so that rotating them does
    # not disturb the global list.
    regs = copy(sp)
    s(ix::Int) = regs[ix + 1]

    # Rotate the register names left by four, so that s(0), ..., s(3) name
    # registers that the previous pair of limbs did not use.
    rotate!() = (regs .= circshift(regs, -4); nothing)

    print(io, "PROLOGUE(flint_mpn_aors($n))\n")

    # First pair of limbs: OP starts the chain, OPC continues it.
    ldp(s(0), s(2), ap(0))
    ldp(s(1), s(3), bp(0))
    OP(s(0), s(0), s(1))
    OPC(s(2), s(2), s(3))
    stp(s(0), s(2), rp(0))

    # Remaining full pairs of limbs.
    for ix in 1:(n ÷ 2 - 1)
        rotate!()
        ldp(s(0), s(2), ap(2 * ix))
        ldp(s(1), s(3), bp(2 * ix))
        OPC(s(0), s(0), s(1))
        OPC(s(2), s(2), s(3))
        stp(s(0), s(2), rp(2 * ix))
    end

    # Trailing single limb when n is odd.
    if n % 2 == 1
        ldr(s(4), ap(n - 1))
        ldr(s(5), bp(n - 1))
        OPC(s(4), s(4), s(5))
        str(s(4), rp(n - 1))
    end

    cset(sx, CC)

    print(io, "\tret\nEPILOGUE()\n")

    return String(take!(io))
end

# Prints the generated routine for every size n = 2, ..., nmax to standard
# output, each followed by a newline.
function print_all_aors(nmax::Int = 16)
    foreach(2:nmax) do n
        println(aors(n))
    end
    return nothing
end
83 changes: 83 additions & 0 deletions dev/gen_x86_aors.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
#
# Copyright (C) 2024 Albin Ahlbäck
#
# This file is part of FLINT.
#
# FLINT is free software: you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License (LGPL) as published
# by the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version. See <https://www.gnu.org/licenses/>.
#

# Generating routines for r <- a OP b, where OP is either + or -.
#
# This generator was written with processors that have decent schedulers in
# mind.

# Placeholder names for the result pointer and the two source pointers.
r = "rp"
a = "ap"
b = "bp"

# Memory operand for the word at index ix (scaled by 8) relative to a pointer.
rp(ix::Int) = string(ix, "*8(", r, ")")
ap(ix::Int) = string(ix, "*8(", a, ")")
bp(ix::Int) = string(ix, "*8(", b, ")")

sx = "sx" # Return value for carry or borrow, i.e. %rax

# Wrap a register name in the R32(...) and R8(...) macros, respectively.
R32(sx::String) = string("R32(", sx, ")")
R8(sx::String) = string("R8(", sx, ")")

sp = map(ix -> "s$ix", 0:4) # Scratch registers s0, ..., s4

# Returns, as a String, assembly for flint_mpn_aors(n) that should be
# preprocessed by M4.
#
# Limb ix + 1 of ap is loaded one step before it is combined with bp, so each
# load is emitted ahead of the OP/OPC that consumes the previous limb.  sx is
# zeroed with xor before the first OP and receives the final carry/borrow via
# setc.
#
# n must be at least 2, because limbs 0 and 1 of ap are loaded unconditionally.
# An ArgumentError is thrown otherwise (previously incorrect assembly was
# generated silently).
function aors(n::Int)
    n >= 2 || throw(ArgumentError("aors requires n >= 2, got n = $n"))

    # Collect the output in a buffer instead of repeatedly concatenating strings.
    io = IOBuffer()

    # Append one instruction line of the form "\t<mnemonic>\t<op>, <op>\n".
    emit(mnemonic::String, operands::String...) =
        print(io, '\t', mnemonic, '\t', join(operands, ", "), '\n')

    mov(s0::String, s1::String) = emit("mov", s0, s1)
    # Trailing underscore avoids shadowing Base.xor inside this function.
    xor_(s0::String, s1::String) = emit("xor", s0, s1)
    OP(s0::String, s1::String) = emit("OP", s0, s1)
    OPC(s0::String, s1::String) = emit("OPC", s0, s1)
    setc(s0::String) = emit("setc", s0)

    # Private copy of the scratch register names, so that rotating them does
    # not disturb the global list.
    regs = copy(sp)
    s(ix::Int) = regs[ix + 1]

    # Rotate the register names left by one: the register that was s(1), which
    # already holds the next limb of ap, becomes s(0).
    rotate!() = (regs .= circshift(regs, -1); nothing)

    print(io, "\tALIGN(16)\nPROLOGUE(flint_mpn_aors($n))\n")

    mov(ap(0), s(0))

    mov(ap(1), s(1))
    # xor is emitted before OP, so the flags produced by OP are not disturbed.
    xor_(R32(sx), R32(sx))
    OP(bp(0), s(0))
    mov(s(0), rp(0))

    for ix in 1:(n - 2)
        rotate!()
        mov(ap(ix + 1), s(1))
        OPC(bp(ix), s(0))
        mov(s(0), rp(ix))
    end

    # Last limb: it was already loaded into s(1) by the preceding step.
    OPC(bp(n - 1), s(1))
    mov(s(1), rp(n - 1))
    setc(R8(sx))

    print(io, "\tret\nEPILOGUE()\n")

    return String(take!(io))
end

# Prints the generated routine for every size n = 2, ..., nmax to standard
# output, each followed by a newline.
function print_all_aors(nmax::Int = 16)
    foreach(2:nmax) do n
        println(aors(n))
    end
    return nothing
end
41 changes: 41 additions & 0 deletions src/mpn_extras.h
Original file line number Diff line number Diff line change
Expand Up @@ -462,25 +462,34 @@ mp_limb_t mpn_rsh1sub_n(mp_ptr, mp_srcptr, mp_srcptr, mp_size_t);

/* multiplication (general) **************************************************/

/* NOTE: This is getting a bit messy. How can we clean this up? */
#if FLINT_HAVE_ASSEMBLY_x86_64_adx
# define FLINT_MPN_AORS_FUNC_TAB_WIDTH 17
# define FLINT_MPN_MUL_FUNC_TAB_WIDTH 17
# define FLINT_MPN_SQR_FUNC_TAB_WIDTH 14

# define FLINT_HAVE_AORS_FUNC(n) ((n) < FLINT_MPN_AORS_FUNC_TAB_WIDTH)
# define FLINT_HAVE_MUL_FUNC(n, m) ((n) <= 16)
# define FLINT_HAVE_MUL_N_FUNC(n) ((n) <= 16)
# define FLINT_HAVE_SQR_FUNC(n) ((n) <= FLINT_MPN_SQR_FUNC_TAB_WIDTH)

# define FLINT_MPN_ADD_HARD(rp, xp, yp, n) (flint_mpn_add_func_tab[n](rp, xp, yp))
# define FLINT_MPN_SUB_HARD(rp, xp, yp, n) (flint_mpn_sub_func_tab[n](rp, xp, yp))
# define FLINT_MPN_MUL_HARD(rp, xp, xn, yp, yn) (flint_mpn_mul_func_tab[xn][yn](rp, xp, yp))
# define FLINT_MPN_MUL_N_HARD(rp, xp, yp, n) (flint_mpn_mul_n_func_tab[n](rp, xp, yp))
# define FLINT_MPN_SQR_HARD(rp, xp, n) (flint_mpn_sqr_func_tab[n](rp, xp))
#elif FLINT_HAVE_ASSEMBLY_armv8
# define FLINT_MPN_AORS_FUNC_TAB_WIDTH 17
# define FLINT_MPN_MUL_FUNC_N_TAB_WIDTH 15
# define FLINT_MPN_SQR_FUNC_TAB_WIDTH 9

# define FLINT_HAVE_AORS_FUNC(n) ((n) < FLINT_MPN_AORS_FUNC_TAB_WIDTH)
# define FLINT_HAVE_MUL_FUNC(n, m) FLINT_HAVE_MUL_N_FUNC(n)
# define FLINT_HAVE_MUL_N_FUNC(n) ((n) <= FLINT_MPN_MUL_FUNC_N_TAB_WIDTH)
# define FLINT_HAVE_SQR_FUNC(n) ((n) <= FLINT_MPN_SQR_FUNC_TAB_WIDTH)

# define FLINT_MPN_ADD_HARD(rp, xp, yp, n) (flint_mpn_add_func_tab[n](rp, xp, yp))
# define FLINT_MPN_SUB_HARD(rp, xp, yp, n) (flint_mpn_sub_func_tab[n](rp, xp, yp))
# define FLINT_MPN_MUL_HARD(rp, xp, xn, yp, yn) (flint_mpn_mul_func_n_tab[xn](rp, xp, yp, yn))
# define FLINT_MPN_MUL_N_HARD(rp, xp, yp, n) (flint_mpn_mul_func_n_tab[n](rp, xp, yp, n))
# define FLINT_MPN_SQR_HARD(rp, xp, n) (flint_mpn_sqr_func_tab[n](rp, xp))
Expand All @@ -506,6 +515,16 @@ typedef mp_limb_t (* flint_mpn_mul_func_t)(mp_ptr, mp_srcptr, mp_srcptr);
typedef mp_limb_t (* flint_mpn_mul_func_n_t)(mp_ptr, mp_srcptr, mp_srcptr, mp_size_t);
typedef mp_limb_t (* flint_mpn_sqr_func_t)(mp_ptr, mp_srcptr);

/* Hardcoded addition/subtraction tables.  When FLINT_MPN_AORS_FUNC_TAB_WIDTH
 * is defined, the tables are declared here and FLINT_USE_AORS_FUNC_TAB is set.
 * Otherwise the AORS macros are stubbed out: FLINT_HAVE_AORS_FUNC(n) is the
 * constant 0, so the *_HARD macros (also 0) sit in branches that are never
 * taken. */
#ifdef FLINT_MPN_AORS_FUNC_TAB_WIDTH
# define FLINT_USE_AORS_FUNC_TAB 1
/* Indexed by the number of limbs n; see FLINT_MPN_ADD_HARD/FLINT_MPN_SUB_HARD. */
FLINT_DLL extern const flint_mpn_mul_func_t flint_mpn_add_func_tab[];
FLINT_DLL extern const flint_mpn_mul_func_t flint_mpn_sub_func_tab[];
#else
# define FLINT_HAVE_AORS_FUNC(n) 0
# define FLINT_MPN_ADD_HARD(rp, xp, yp, n) 0
# define FLINT_MPN_SUB_HARD(rp, xp, yp, n) 0
#endif

#ifdef FLINT_MPN_MUL_FUNC_N_TAB_WIDTH
FLINT_DLL extern const flint_mpn_mul_func_n_t flint_mpn_mul_func_n_tab[];
#else
Expand All @@ -522,6 +541,28 @@ mp_limb_t _flint_mpn_mul(mp_ptr r, mp_srcptr x, mp_size_t xn, mp_srcptr y, mp_si
void _flint_mpn_mul_n(mp_ptr r, mp_srcptr x, mp_srcptr y, mp_size_t n);
mp_limb_t _flint_mpn_sqr(mp_ptr r, mp_srcptr x, mp_size_t n);

/* Sets {rp, n} to {xp, n} + {yp, n} and returns the value produced by the
 * routine that was dispatched to.  Requires n >= 1 (asserted).  Sizes for
 * which FLINT_HAVE_AORS_FUNC(n) holds use the hardcoded routine; all other
 * sizes are passed on to mpn_add_n. */
MPN_EXTRAS_INLINE
mp_limb_t flint_mpn_add_n(mp_ptr rp, mp_srcptr xp, mp_srcptr yp, mp_size_t n)
{
    FLINT_ASSERT(n >= 1);

    /* No hardcoded routine for this size: use the generic one. */
    if (!FLINT_HAVE_AORS_FUNC(n))
        return mpn_add_n(rp, xp, yp, n);

    return FLINT_MPN_ADD_HARD(rp, xp, yp, n);
}

/* Sets {rp, n} to {xp, n} - {yp, n} and returns the value produced by the
 * routine that was dispatched to.  Requires n >= 1 (asserted).  Sizes for
 * which FLINT_HAVE_AORS_FUNC(n) holds use the hardcoded routine; all other
 * sizes are passed on to mpn_sub_n. */
MPN_EXTRAS_INLINE
mp_limb_t flint_mpn_sub_n(mp_ptr rp, mp_srcptr xp, mp_srcptr yp, mp_size_t n)
{
    FLINT_ASSERT(n >= 1);

    /* No hardcoded routine for this size: use the generic one. */
    if (!FLINT_HAVE_AORS_FUNC(n))
        return mpn_sub_n(rp, xp, yp, n);

    return FLINT_MPN_SUB_HARD(rp, xp, yp, n);
}

MPN_EXTRAS_INLINE mp_limb_t
flint_mpn_mul(mp_ptr r, mp_srcptr x, mp_size_t xn, mp_srcptr y, mp_size_t yn)
{
Expand Down
88 changes: 88 additions & 0 deletions src/mpn_extras/aors_n.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/*
Copyright (C) 2024 Albin Ahlbäck
This file is part of FLINT.
FLINT is free software: you can redistribute it and/or modify it under
the terms of the GNU Lesser General Public License (LGPL) as published
by the Free Software Foundation; either version 3 of the License, or
(at your option) any later version. See <https://www.gnu.org/licenses/>.
*/

#include "mpn_extras.h"

/* DECL_AORS(n) declares the prototypes flint_mpn_add_<n> and
 * flint_mpn_sub_<n>.  The extra level of indirection through _DECL_AORS lets
 * the argument be macro-expanded before it is pasted with ##.  No trailing
 * semicolon is included; the user of the macro supplies it. */
#define DECL_AORS(n) _DECL_AORS(n)
#define _DECL_AORS(n) \
mp_limb_t flint_mpn_add_##n(mp_ptr, mp_srcptr, mp_srcptr); \
mp_limb_t flint_mpn_sub_##n(mp_ptr, mp_srcptr, mp_srcptr)

/* ADD(n) and SUB(n) expand to the identifiers flint_mpn_add_<n> and
 * flint_mpn_sub_<n>, using the same two-step expansion as above. */
#define ADD(n) _ADD(n)
#define _ADD(n) flint_mpn_add_##n
#define SUB(n) _SUB(n)
#define _SUB(n) flint_mpn_sub_##n

/* Herein we assume that x86 and ARM are equivalent, i.e. that both provide
 * flint_mpn_add_<n> and flint_mpn_sub_<n> for n = 1, ..., 16. */
#if FLINT_HAVE_ASSEMBLY_x86_64_adx || FLINT_HAVE_ASSEMBLY_armv8
/* Prototypes for the hardcoded routines; this file only declares them. */
DECL_AORS(1);
DECL_AORS(2);
DECL_AORS(3);
DECL_AORS(4);
DECL_AORS(5);
DECL_AORS(6);
DECL_AORS(7);
DECL_AORS(8);
DECL_AORS(9);
DECL_AORS(10);
DECL_AORS(11);
DECL_AORS(12);
DECL_AORS(13);
DECL_AORS(14);
DECL_AORS(15);
DECL_AORS(16);

/* TODO: Should probably rename these types so to not have two different types.
 * Probably something like `mpn_binary_h_func`, where `h` is for hardcoded. */
/* Addition table: entry n is the routine for n limbs.  Entry 0 is NULL, as
 * there is no routine for zero limbs.  The 17 entries correspond to
 * FLINT_MPN_AORS_FUNC_TAB_WIDTH in mpn_extras.h; keep the two in sync. */
const flint_mpn_mul_func_t flint_mpn_add_func_tab[] =
{
NULL,
ADD(1),
ADD(2),
ADD(3),
ADD(4),
ADD(5),
ADD(6),
ADD(7),
ADD(8),
ADD(9),
ADD(10),
ADD(11),
ADD(12),
ADD(13),
ADD(14),
ADD(15),
ADD(16)
};

/* Subtraction table, laid out exactly like the addition table. */
const flint_mpn_mul_func_t flint_mpn_sub_func_tab[] =
{
NULL,
SUB(1),
SUB(2),
SUB(3),
SUB(4),
SUB(5),
SUB(6),
SUB(7),
SUB(8),
SUB(9),
SUB(10),
SUB(11),
SUB(12),
SUB(13),
SUB(14),
SUB(15),
SUB(16)
};
#else
/* Keeps the translation unit non-empty when no hardcoded routines exist. */
typedef int this_file_is_empty;
#endif
Loading

0 comments on commit d09c98e

Please sign in to comment.