mpn_mulhigh_basecase for Broadwell
Albin Ahlbäck
albin.ahlback at gmail.com
Tue Feb 6 15:03:54 CET 2024
Hello,
I just wrote an implementation for mpn_mulhigh_basecase for
Broadwell-type processors (that is, x86_64 with BMI2 and ADX
instructions) based on Torbjörn's `mpn_mullo_basecase'.
It is currently declared as
mp_limb_t flint_mpn_mulhigh_basecase(mp_ptr rp, mp_srcptr xp, mp_srcptr
yp, mp_size_t n),
as it was developed for FLINT (Fast Library for Number Theory). Note
that `rp' cannot be aliased with `xp' or `yp'.
It computes an approximation of the uppermost `n + 1' limbs of
`{xp, n}' times `{yp, n}'. The uppermost `n' limbs are stored to
`rp[0]', ..., `rp[n - 1]', and the least significant computed limb is
returned (via %rax). This returned limb has an error on the order of
`n' ULPs.
Note that this differs from MPFR's (private) function `mpfr_mulhigh_n',
which computes an approximation of the uppermost `n' limbs into
`rp[n]', ..., `rp[2 * n - 1]', where `rp[n]' has an error of at most
about `n' ULPs.
Feel free to change it according to your needs (perhaps you do not want
to compute `n + 1' limbs, but rather `n' limbs).
If this code will be used in GMP, feel free to remove the copyright
claim for FLINT and put my name (spelled Albin Ahlbäck) in the GMP
copyright claim instead.
Just some notes:
- We use our own M4 syntax for the beginning and ending of the function,
but it should be easy to translate to GMP's syntax.
- It currently only works for n > 5 (I believe), since FLINT has
specialized routines for small n.
- It would be nice to avoid pushing five registers and push only four.
- Reduce the size of the `L(end)' code, and try to avoid multiple jumps
therein.
- Move the code-snippet of `L(f2)' to just above `L(b2)', so that no
jump is needed in between. (This currently does not work because
`L(end)' as well as this code-snippet is too large for a relative 8-bit
jump.)
- Start with a mul_1 sequence using just a mulx+add+adc chain, just
like in `mpn_mullo_basecase'.
- Remove the first multiplication in each `L(fX)' and put it in `L(end)'
instead.
- The `adcx' instruction in `L(fX)' can be removed (then one needs to
adjust the `L(bX)'-label), but I found it to be slower. Can we remove it
and somehow maintain the same performance?
Best,
Albin
-------------- next part --------------
dnl X86-64 mpn_mulhigh_basecase for Intel Broadwell, derived from mpn_mullo_basecase.
dnl Contributed to the GNU project by Torbjorn Granlund.
dnl Copyright 2017 Free Software Foundation, Inc.
dnl This file is part of the GNU MP Library.
dnl
dnl The GNU MP Library is free software; you can redistribute it and/or modify
dnl it under the terms of either:
dnl
dnl * the GNU Lesser General Public License as published by the Free
dnl Software Foundation; either version 3 of the License, or (at your
dnl option) any later version.
dnl
dnl or
dnl
dnl * the GNU General Public License as published by the Free Software
dnl Foundation; either version 2 of the License, or (at your option) any
dnl later version.
dnl
dnl or both in parallel, as here.
dnl
dnl The GNU MP Library is distributed in the hope that it will be useful, but
dnl WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
dnl or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
dnl for more details.
dnl
dnl You should have received copies of the GNU General Public License and the
dnl GNU Lesser General Public License along with the GNU MP Library. If not,
dnl see https://www.gnu.org/licenses/.
dnl
dnl Copyright (C) 2024 Albin Ahlbäck
dnl
dnl This file is part of FLINT.
dnl
dnl FLINT is free software: you can redistribute it and/or modify it under
dnl the terms of the GNU Lesser General Public License (LGPL) as published
dnl by the Free Software Foundation; either version 3 of the License, or
dnl (at your option) any later version. See <https://www.gnu.org/licenses/>.
dnl
include(`config.m4')
include(`src/mpn_extras/broadwell/asm-defs.m4')
# Register assignments.
#
# Incoming arguments, in SysV AMD64 order.  The second operand pointer
# arrives in rdx and is copied to r8 on entry, because mulx takes its
# implicit multiplicand from rdx.
define(`rp', `%rdi')
define(`ap', `%rsi')
define(`n', `%rcx')
define(`bp_param', `%rdx')
define(`bp', `%r8')

# Row control for the addmul rows.
define(`jmpreg', `%r9')
define(`nn', `%r10')
define(`m', `%r13')
define(`mm', `%r14')

# Limb accumulators and product temporaries.  The first one is also the
# return value register.  rbx, rbp, r12, r13 and r14 are callee-saved
# and are pushed by the function on entry.
define(`rx', `%rax')
define(`r0', `%r11')
define(`r1', `%rbx')
define(`r2', `%rbp')
define(`r3', `%r12')
# mp_limb_t flint_mpn_mulhigh_basecase(mp_ptr rp, mp_srcptr ap,
#                                      mp_srcptr bp, mp_size_t n)
#
# Approximate high product for x86-64 with BMI2 (mulx) and ADX (adcx, adox),
# modelled on the Broadwell/Skylake mpn_mullo_basecase: 8-way unrolled
# adcx/adox addmul rows entered through a rotating set of entry points.
#
# Let position k denote the weight of limb k of the 2n-limb product.
# For every limb bp[j] the full products bp[j]*ap[i] with i + j >= n - 1
# are accumulated, plus only the high half of bp[j]*ap[i] with
# i + j = n - 2.  Everything of lower weight is dropped.
#
# Out:   rp[0] .. rp[n - 1] = positions n .. 2n - 1 of the approximation
#        rax                = position n - 1 of the approximation
# ABI:   SysV AMD64 only (rdi and rsi are clobbered without being saved).
# In:    rdi = rp, rsi = ap, rdx = bp, rcx = n
# Clobb: rcx, rdx, rsi, rdi, r8, r9, r10, r11, flags
# Saves: rbx, rbp, r12, r13, r14 (pushed on entry, popped before ret)
# Requires:
#   n >= 6, since the first addmul row reads ap[n - 6] unconditionally.
#   rp must not overlap ap or bp, since rp is written while the operands
#   are still being read.
	.text
	.global	FUNC(flint_mpn_mulhigh_basecase)
	ALIGN(32)
	TYPE(flint_mpn_mulhigh_basecase)
FUNC(flint_mpn_mulhigh_basecase):
	.cfi_startproc
	mov	bp_param, bp		# free rdx for use as the mulx multiplicand
	lea	-1*8(ap,n,8), ap	# ap now points at ap[n - 1]
	push	%rbx
	.cfi_adjust_cfa_offset 8
	.cfi_offset %rbx, -16
	push	%rbp
	.cfi_adjust_cfa_offset 8
	.cfi_offset %rbp, -24
	push	%r12
	.cfi_adjust_cfa_offset 8
	.cfi_offset %r12, -32
	push	%r13
	.cfi_adjust_cfa_offset 8
	.cfi_offset %r13, -40
	push	%r14
	.cfi_adjust_cfa_offset 8
	.cfi_offset %r14, -48

	# Initial triangle: rows bp[0] .. bp[3], fully unrolled.  In the
	# diagram below, h marks a product of which only the high half is
	# used and x marks a full product.
	#   h
	#   h x
	#   h x x
	#   h x x x
	#   x x x x
	# Afterwards rax holds position n - 1 and rp[0] .. rp[3] hold
	# positions n .. n + 3.  The scratch registers s0 .. s3 borrow the
	# row-control registers, which are not in use yet.
define(`s0', `jmpreg')
define(`s1', `m')
define(`s2', `mm')
define(`s3', `nn')

	# Row bp[0].
	mov	0*8(bp), %rdx
	xor	R32(s3), R32(s3)	# s3 = 0, used as a zero carry source
	mulx	-1*8(ap), rx, rx	# same destination twice keeps only the high half
	mulx	0*8(ap), s0, r0
	add	s0, rx
	adc	s3, r0

	# Row bp[1].
	mov	1*8(bp), %rdx
	mulx	-2*8(ap), s1, s1	# high half only
	mulx	-1*8(ap), r3, r2
	mulx	0*8(ap), s0, r1
	add	r3, rx
	adc	s0, r0
	adc	s3, r1
	add	s1, rx
	adc	r2, r0
	adc	s3, r1

	# Row bp[2].
	mov	2*8(bp), %rdx
	mulx	-3*8(ap), s0, s0	# high half only
	mulx	-2*8(ap), r3, s1
	add	s0, rx
	adc	s1, r0
	mulx	-1*8(ap), s0, s1
	mulx	0*8(ap), %rdx, r2	# last product of the row, so rdx may be reused
	adc	s1, r1
	adc	s3, r2
	add	r3, rx
	adc	s0, r0
	adc	%rdx, r1
	adc	s3, r2

	# Row bp[3].  s3 is overwritten below, so later carries use $0.
	mov	3*8(bp), %rdx
	mulx	-4*8(ap), s1, s1	# high half only
	mulx	-3*8(ap), s0, s2
	add	s1, rx
	adc	s2, r0
	mulx	-2*8(ap), s1, r3
	mulx	-1*8(ap), s2, s3
	adc	r3, r1
	adc	s3, r2
	mulx	0*8(ap), %rdx, r3
	adc	$0, r3
	add	s0, rx
	adc	s1, r0
	adc	s2, r1
	mov	r0, 0*8(rp)
	mov	r1, 1*8(rp)
	adc	%rdx, r2
	adc	$0, r3
	mov	r2, 2*8(rp)
	mov	r3, 3*8(rp)
undefine(`s0')
undefine(`s1')
undefine(`s2')
undefine(`s3')

	# Addmul rows bp[4] .. bp[n - 1].
	# Row j adds bp[j] times ap[n - 1 - j] .. ap[n - 1] in full into
	# rax and rp[0] .. rp[j - 1], and stores its top limb to rp[j].  It
	# also adds the high half of bp[j] * ap[n - 2 - j] into rax, except
	# for the last row j = n - 1, where that limb does not exist.
	# Each row is one limb longer than the previous one, so the entry
	# point rotates f4, f5, f6, f7, f0, f1, f2, f3, f4, and so on.
	# Register invariants:
	#   m      = -8 * j for the current row j
	#   mm     = -8 * (n - 1), so |m| = |mm| marks the last row
	#   n      = remaining 8-limb iterations of the unrolled loop
	#   nn     = value n is reset to at the start of each row
	#   jmpreg = entry point of the next row
	# Two carry chains run in parallel: adcx links each product low half
	# to the previous high half, and adox adds that chain into rax or the
	# existing result limbs.
	lea	-1*8(,n,8), R32(mm)
	lea	4*8(bp), bp		# bp now points at bp[4]
	lea	-3*8(ap), ap		# ap now points at ap[n - 4]
	mov	$-4*8, m		# m = -8 * 4 for row 4
	neg	mm			# mm = -8 * (n - 1)
	mov	0*8(bp), %rdx
	xor	R32(nn), R32(nn)	# nn = 0
	xor	R32(n), R32(n)		# n = 0, also clears CF and OF
	mulx	-2*8(ap), r1, r1	# high half of bp[4] * ap[n - 6], needs n >= 6
	adcx	r1, rx

	# Row entry points.  Each one forms bp[j] * ap[n - 1 - j] and
	# bp[j] * ap[n - j], adds the low half of the first into rax, and
	# starts the adcx chain.  It then positions ap and rp for the
	# matching loop entry, loads jmpreg with the entry of row j + 1, and
	# jumps into the loop.
L(f4):	mulx	-1*8(ap), r2, r3
	mulx	0*8(ap), r0, r1
	adox	r2, rx
	adcx	r3, r0
	lea	3*8(ap), ap
	lea	-5*8(rp), rp
	lea	L(f5)(%rip), jmpreg
	jmp	L(b4)
L(f0):	mulx	-1*8(ap), r2, r3
	mulx	0*8(ap), r0, r1
	adox	r2, rx
	adcx	r3, r0
	lea	-1*8(ap), ap
	lea	-1*8(rp), rp
	lea	L(f1)(%rip), jmpreg
	jmp	L(b0)
L(f1):	mulx	-1*8(ap), r0, r1
	mulx	0*8(ap), r2, r3
	adox	r0, rx
	adcx	r1, r2
	lea	1(nn), R32(nn)		# from this row on, one more loop iteration
	lea	1(n), R32(n)
	lea	L(f2)(%rip), jmpreg
	jmp	L(b1)
L(f7):	mulx	-1*8(ap), r0, r1
	mulx	0*8(ap), r2, r3
	adox	r0, rx
	adcx	r1, r2
	lea	-2*8(ap), ap
	lea	-2*8(rp), rp
	lea	L(f0)(%rip), jmpreg
	jmp	L(b7)
L(f2):	mulx	-1*8(ap), r2, r3
	mulx	0*8(ap), r0, r1
	adox	r2, rx
	adcx	r3, r0
	lea	1*8(ap), ap
	lea	1*8(rp), rp
	mulx	0*8(ap), r2, r3		# preload the product that b2 expects in r2 and r3
	lea	L(f3)(%rip), jmpreg
	jmp	L(b2)

	# End of row j.  Add and store the last updated limb, then store the
	# new top limb with both pending carries (n is zero here).  Rewind
	# ap to ap[n - 1 - j] and rp to rp[0], and advance to row j + 1.
L(end):	adox	0*8(rp), r2
	mov	r2, 0*8(rp)
	adox	n, r3			# n = 0, adds OF
	adc	n, r3			# n = 0, adds CF
	add	m, ap			# rewind ap
	mov	r3, 1*8(rp)
	lea	-1*8(m), m		# m = -8 * (j + 1)
	lea	1*8(bp), bp		# next limb of bp
	lea	2*8(rp,m), rp		# rewind rp to rp[0]
	mov	0*8(bp), %rdx		# load the next limb of bp
	cmp	R32(m), R32(mm)
	jge	L(jmp)
	# |m| < |mm|: not the last row.  Reset n to nn, which also clears CF
	# and OF, then add the high half of the product one limb below the
	# row before entering it.
	or	R32(nn), R32(n)
	mulx	-2*8(ap), r1, r1
	adcx	r1, rx
	jmp	*jmpreg
	# |m| > |mm|: all rows are done.
L(jmp):	jg	L(fin)
	# |m| = |mm|: last row, which has no high-half-only product.  Reset
	# n, clear CF and OF, and enter the row.
	or	R32(nn), R32(n)
	jmp	*jmpreg

	# Unrolled addmul loop, 8 limbs per iteration.
	ALIGN(32)
L(b2):	adox	-1*8(rp), r0
	adcx	r1, r2
	mov	r0, -1*8(rp)
	jrcxz	L(end)			# row finished when n reaches 0
L(b1):	mulx	1*8(ap), r0, r1
	adox	0*8(rp), r2
	lea	-1(n), R32(n)
	mov	r2, 0*8(rp)
	adcx	r3, r0
L(b0):	mulx	2*8(ap), r2, r3
	adcx	r1, r2
	adox	1*8(rp), r0
	mov	r0, 1*8(rp)
L(b7):	mulx	3*8(ap), r0, r1
	lea	8*8(ap), ap
	adcx	r3, r0
	adox	2*8(rp), r2
	mov	r2, 2*8(rp)
L(b6):	mulx	-4*8(ap), r2, r3
	adox	3*8(rp), r0
	adcx	r1, r2
	mov	r0, 3*8(rp)
L(b5):	mulx	-3*8(ap), r0, r1
	adcx	r3, r0
	adox	4*8(rp), r2
	mov	r2, 4*8(rp)
L(b4):	mulx	-2*8(ap), r2, r3
	adox	5*8(rp), r0
	adcx	r1, r2
	mov	r0, 5*8(rp)
L(b3):	adox	6*8(rp), r2
	mulx	-1*8(ap), r0, r1
	mov	r2, 6*8(rp)
	lea	8*8(rp), rp
	adcx	r3, r0
	mulx	0*8(ap), r2, r3
	jmp	L(b2)

	# Remaining row entry points.  These are placed after the loop so
	# that the end-of-row code stays within jrcxz range of the loop.
L(f6):	mulx	-1*8(ap), r2, r3
	mulx	0*8(ap), r0, r1
	adox	r2, rx
	adcx	r3, r0
	lea	5*8(ap), ap
	lea	-3*8(rp), rp
	lea	L(f7)(%rip), jmpreg
	jmp	L(b6)
L(f5):	mulx	-1*8(ap), r0, r1
	mulx	0*8(ap), r2, r3
	adox	r0, rx
	adcx	r1, r2
	lea	4*8(ap), ap
	lea	-4*8(rp), rp
	lea	L(f6)(%rip), jmpreg
	jmp	L(b5)
L(f3):	mulx	-1*8(ap), r0, r1
	mulx	0*8(ap), r2, r3
	adox	r0, rx
	adcx	r1, r2
	lea	2*8(ap), ap
	lea	-6*8(rp), rp
	lea	L(f4)(%rip), jmpreg
	jmp	L(b3)

	# Epilogue.  rax already holds the returned limb.  This is the last
	# code in the function, so the CFI below describes only this path.
L(fin):	pop	%r14
	.cfi_adjust_cfa_offset -8
	.cfi_restore %r14
	pop	%r13
	.cfi_adjust_cfa_offset -8
	.cfi_restore %r13
	pop	%r12
	.cfi_adjust_cfa_offset -8
	.cfi_restore %r12
	pop	%rbp
	.cfi_adjust_cfa_offset -8
	.cfi_restore %rbp
	pop	%rbx
	.cfi_adjust_cfa_offset -8
	.cfi_restore %rbx
	ret
.flint_mpn_mulhigh_basecase_end:
	SIZE(flint_mpn_mulhigh_basecase, .flint_mpn_mulhigh_basecase_end)
	.cfi_endproc
More information about the gmp-devel
mailing list