Albin Ahlbäck albin.ahlback at gmail.com
Tue Feb 6 15:03:54 CET 2024

```Hello,

I just wrote an implementation of mpn_mulhigh_basecase for Intel Broadwell
(i.e. x86-64 CPUs supporting the mulx, adcx and adox
instructions), based on Torbjörn's `mpn_mullo_basecase'.

It is currently declared as

mp_limb_t flint_mpn_mulhigh_basecase(mp_ptr rp, mp_srcptr xp, mp_srcptr
yp, mp_size_t n),

as it was developed for FLINT (Fast Library for Number Theory). Note
that `rp' cannot be aliased with `xp' or `yp'.

It computes an approximation of the most significant `n + 1' limbs of
`{xp, n}' times `{yp, n}': the most significant `n' limbs are written to
`rp[0]', ..., `rp[n - 1]', and the least significant computed limb is
returned (via %rax). This returned limb should have an error on the
order of `n' ULPs.

Note that this differs from MPFR's (private) function `mpfr_mulhigh_n',
which computes an approximation of the most significant `n' limbs into
`rp[n]', ..., `rp[2 * n - 1]', where `rp[n]' has an error of at most
about `n' ULPs.

Feel free to change it according to your needs (perhaps you do not want
to compute `n + 1' limbs, but rather `n' limbs).

If this code is used in GMP, feel free to remove the FLINT copyright
claim and put my name (spelled Albin Ahlbäck) in the GMP contributor list.

Just some notes:

- We use our own M4 syntax for the beginning and ending of the function,
but it should be easy to translate to GMP's syntax.
- It currently only works for n > 5 (I believe), since FLINT has
specialized routines for small n.
- It would be nice to avoid pushing five registers and push only four.
- Reduce the size of the `L(end)' code, and try to avoid multiple jumps
therein.
- Move the code snippet of `L(f2)' to just above `L(b2)', so that no
jump is needed in between. (This currently does not work because
`L(end)' and this code snippet together are too large for an 8-bit
relative jump.)
- Start out with a mul_1 sequence using just a mulx+add+adc chain, just
like in `mpn_mullo_basecase'.
- Remove the first multiplication in each `L(fX)' and put it in `L(end)'.
- The `adcx' instruction in `L(fX)' can be removed (then one needs to
adjust the `L(bX)'-label), but I found it to be slower. Can we remove it
and somehow maintain the same performance?

Best,
Albin
-------------- next part --------------
dnl  x86-64 mpn_mulhigh_basecase optimised for Intel Broadwell, derived from
dnl  the GMP mpn_mullo_basecase for Broadwell.

dnl  Contributed to the GNU project by Torbjorn Granlund.

dnl  Copyright 2017 Free Software Foundation, Inc.

dnl  This file is part of the GNU MP Library.
dnl
dnl  The GNU MP Library is free software; you can redistribute it and/or modify
dnl  it under the terms of either:
dnl
dnl    * the GNU Lesser General Public License as published by the Free
dnl      Software Foundation; either version 3 of the License, or (at your
dnl      option) any later version.
dnl
dnl  or
dnl
dnl    * the GNU General Public License as published by the Free Software
dnl      Foundation; either version 2 of the License, or (at your option) any
dnl      later version.
dnl
dnl  or both in parallel, as here.
dnl
dnl  The GNU MP Library is distributed in the hope that it will be useful, but
dnl  WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
dnl  or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
dnl  for more details.
dnl
dnl  You should have received copies of the GNU General Public License and the
dnl  GNU Lesser General Public License along with the GNU MP Library.  If not,
dnl  see https://www.gnu.org/licenses/.
dnl
dnl   Copyright (C) 2024 Albin Ahlbäck
dnl
dnl   This file is part of FLINT.
dnl
dnl   FLINT is free software: you can redistribute it and/or modify it under
dnl   the terms of the GNU Lesser General Public License (LGPL) as published
dnl   by the Free Software Foundation; either version 3 of the License, or
dnl   (at your option) any later version.  See <https://www.gnu.org/licenses/>.
dnl

include(`config.m4')

dnl  Register map.  On entry (System V AMD64) the four arguments arrive in
dnl  rdi, rsi, rdx and rcx as rp, ap, bp and n respectively.
define(`rp', `%rdi')
define(`ap', `%rsi')
define(`bp_param', `%rdx')
define(`n', `%rcx')

dnl  mulx reads its multiplier implicitly from rdx, so the bp argument is
dnl  copied to r8 at entry and rdx is then reloaded with one bp limb at a time.
define(`bp', `%r8')

dnl  Row bookkeeping.  jmpreg holds the address of the next row entry point;
dnl  m, mm and nn track row lengths and inner-loop counts between rows.
dnl  r13 and r14 are callee-saved and are pushed by the function prologue.
define(`jmpreg', `%r9')
define(`nn', `%r10')
define(`m', `%r13')
define(`mm', `%r14')

dnl  rax holds the extra low limb that is handed back to the caller.
define(`rx', `%rax')

dnl  Partial-product limbs in flight; rbx, rbp and r12 are callee-saved and
dnl  are pushed by the function prologue.
define(`r0', `%r11')
define(`r1', `%rbx')
define(`r2', `%rbp')
define(`r3', `%r12')

dnl  Idea for later: do something similar to the GMP mpn_mullo_basecase
dnl  for Skylake.

	.text
	.global	FUNC(flint_mpn_mulhigh_basecase)
	ALIGN(32)
	TYPE(flint_mpn_mulhigh_basecase)

# flint_mpn_mulhigh_basecase(rp, xp, yp, n) -> mp_limb_t
#
# Approximates the n + 1 most significant limbs of the 2n-limb product
# {xp, n} * {yp, n}.  The top n limbs are stored to rp[0] .. rp[n - 1]
# and the lowest computed limb is returned in rax; per the accompanying
# notes that limb carries an error of roughly n ULPs.
#
# System V AMD64 arguments: rdi = rp, rsi = xp (ap), rdx = yp (bp_param),
# rcx = n.  rp must not overlap xp or yp.
# Requires BMI2 (mulx) and ADX (adox).  rbx, rbp, r12, r13 and r14 are
# used as scratch; they are pushed on entry and popped at fin.
# NOTE(review): the author believes it only works for n > 5 - unverified.
#
# mulx reminder (AT&T order): mulx src, lo, hi computes rdx * src into
# hi:lo without touching flags.  When lo and hi name the same register
# only the high half is kept (Intel SDM).
#
# NOTE(review): this listing appears to have lost lines and will not
# assemble as shown:
#  - labels b2, b3 and end are jump targets but are never defined;
#  - the code after the jump that ends f2 and the code after the aligned
#    loop start have no label, so nothing can reach them;
#  - no add/adc/adcx/adox appears between partial products in the initial
#    triangle or in the inner loop, so products overwrite each other and
#    are stored without being accumulated into rp;
#  - in the visible code rax is only written by the first mulx.
# Recover the missing instructions from the upstream FLINT source before
# relying on this routine.
# NOTE(review): the five pushes and pops are not described by
# .cfi_adjust_cfa_offset / .cfi_offset, so unwind info is wrong between
# the prologue and ret.
FUNC(flint_mpn_mulhigh_basecase):
.cfi_startproc
# Free rdx for use as the implicit mulx multiplier and point ap at the
# top limb of xp.
mov	bp_param, bp
lea	-1*8(ap,n,8), ap	# ap += n - 1

# Save the callee-saved registers used as scratch below.
push	%rbx
push	%rbp
push	%r12
push	%r13
push	%r14

# Initial triangle
#       h
#     h x
#   h x x
# h x x x
# x x x x
# Rows for the first four bp limbs against the top limbs of ap.  In the
# diagram h presumably marks a product of which only the high limb is
# formed (mulx with identical destinations below) and x a full product.
# The loop-control registers are not live yet, so they are borrowed as
# scratch s0 .. s3 for this section.
define(`s0', `jmpreg')
define(`s1', `m')
define(`s2', `mm')
define(`s3', `nn')
# Row 0.
mov	0*8(bp), %rdx
# Zero s3; the xor also clears CF and OF.
xor	R32(s3), R32(s3)
mulx	-1*8(ap), rx, rx
mulx	0*8(ap), s0, r0

# Row 1.
mov	1*8(bp), %rdx
mulx	-2*8(ap), s1, s1
mulx	-1*8(ap), r3, r2
mulx	0*8(ap), s0, r1

# Row 2; the last mulx may overwrite rdx because row 3 reloads it.
mov	2*8(bp), %rdx
mulx	-3*8(ap), s0, s0
mulx	-2*8(ap), r3, s1
mulx	-1*8(ap), s0, s1
mulx	0*8(ap), %rdx, r2

# Row 3, then store the four top limbs.
# NOTE(review): see header - the accumulation of these products is missing.
mov	3*8(bp), %rdx
mulx	-4*8(ap), s1, s1
mulx	-3*8(ap), s0, s2
mulx	-2*8(ap), s1, r3
mulx	-1*8(ap), s2, s3
mulx	0*8(ap), %rdx, r3
mov	r0, 0*8(rp)
mov	r1, 1*8(rp)
mov	r2, 2*8(rp)
mov	r3, 3*8(rp)
# Drop the scratch aliases; the loop registers take their real roles now.
undefine(`s0')
undefine(`s1')
undefine(`s2')
undefine(`s3')

# - m = -8 * n_cur	(n_cur is the 4 at the start)
# - mm = -8 * (n - 1)	(where n is the original n)
# - n keeps track of how many loops to do in the addmul-loop.
# - nn keeps track of initial n between loops.
# bp moves past the four limbs consumed above and ap moves back three limbs.
lea	-1*8(,n,8), R32(mm)
lea	4*8(bp), bp
lea	-3*8(ap), ap
mov	\$-4*8, m		# m <- -8 * 4
neg	mm			# mm <- -8 * (n - 1)
mov	0*8(bp), %rdx
xor	R32(nn), R32(nn)	# nn <- 0
xor	R32(n), R32(n)		# n <- 0
# High half only of the product just below the first looped row; the row
# epilogue repeats this step before its indirect jump.
mulx	-2*8(ap), r1, r1

# Row entry points f0 .. f7.  The inner loop is unrolled eight times with
# entry labels b0 .. b7.  Each fX forms the first two products of a row,
# rebases ap and rp for the matching bX offsets, stores the address of the
# next entry point in jmpreg (f4 -> f5 -> f6 -> f7 -> f0 -> f1 -> f2 -> f3
# -> f4) and jumps into the loop, so consecutive rows use consecutive
# entry points.
L(f4):	mulx	-1*8(ap), r2, r3
mulx	0*8(ap), r0, r1
lea	3*8(ap), ap
lea	-5*8(rp), rp
lea	L(f5)(%rip), jmpreg
jmp	L(b4)

L(f0):	mulx	-1*8(ap), r2, r3
mulx	0*8(ap), r0, r1
lea	-1*8(ap), ap
lea	-1*8(rp), rp
lea	L(f1)(%rip), jmpreg
jmp	L(b0)

L(f1):	mulx	-1*8(ap), r0, r1
mulx	0*8(ap), r2, r3
# Once every eight rows the unrolled inner loop needs one more pass, so
# both the saved count nn and the live count n are incremented here.
lea	1(nn), R32(nn)
lea	1(n), R32(n)
lea	L(f2)(%rip), jmpreg
jmp	L(b1)

L(f7):	mulx	-1*8(ap), r0, r1
mulx	0*8(ap), r2, r3
lea	-2*8(ap), ap
lea	-2*8(rp), rp
lea	L(f0)(%rip), jmpreg
jmp	L(b7)

L(f2):	mulx	-1*8(ap), r2, r3
mulx	0*8(ap), r0, r1
lea	1*8(ap), ap
lea	1*8(rp), rp
# NOTE(review): this mulx overwrites r2 and r3 from the first mulx of f2
# before they are used - likely another symptom of lost lines.
mulx	0*8(ap), r2, r3
lea	L(f3)(%rip), jmpreg
jmp	L(b2)

# Row epilogue.  NOTE(review): presumably the missing end label (the
# jrcxz target) belongs here; as listed this code is unreachable.
# Stores the last limbs, folds the pending OF and CF carries into r3 using
# n, which is zero here, rewinds ap, grows the row by one limb (m -= 8),
# advances bp, rebases rp, loads the next bp limb, then compares m with mm
# (both negative multiples of 8) to pick the next step.
mov	r2, 0*8(rp)
adox	n, r3		# n = 0
adc	n, r3		# n = 0
add	m, ap		# Reset ap
mov	r3, 1*8(rp)
lea	-1*8(m), m
lea	1*8(bp), bp	# Increase bp
lea	2*8(rp,m), rp	# Reset rp
mov	0*8(bp), %rdx	# Load bp
cmp	R32(m), R32(mm)
jge	L(jmp)
# If |m| < |mm|: goto jmpreg, but first do high part
or	R32(nn), R32(n)	# Reset n, CF and OF
mulx	-2*8(ap), r1, r1
jmp	*jmpreg
# If |m| > |mm|: goto fin
L(jmp):	jg	L(fin)
# If |m| = |mm|: goto jmpreg
or	R32(nn), R32(n)	# Reset n, clear CF and OF
jmp	*jmpreg

# Inner loop, unrolled eight times (ap and rp advance by eight limbs per
# pass).  n counts the remaining passes.  It is decremented with lea so the
# carry flags survive, and jrcxz, which reads no flags, leaves the loop
# when n reaches zero.
# NOTE(review): the loop-top label and the accumulation into rp appear to
# be missing.  From the offset pattern, b3 presumably precedes the mulx at
# -1*8(ap) and b2 the mulx at 0*8(ap).  If so, the closing jump should
# presumably target the loop top rather than b2 - confirm upstream.
ALIGN(32)
mov	r0, -1*8(rp)
jrcxz	L(end)	# Jump if n = 0
L(b1):	mulx	1*8(ap), r0, r1
lea	-1(n), R32(n)
mov	r2, 0*8(rp)
L(b0):	mulx	2*8(ap), r2, r3
mov	r0, 1*8(rp)
L(b7):	mulx	3*8(ap), r0, r1
lea	8*8(ap), ap
mov	r2, 2*8(rp)
L(b6):	mulx	-4*8(ap), r2, r3
mov	r0, 3*8(rp)
L(b5):	mulx	-3*8(ap), r0, r1
mov	r2, 4*8(rp)
L(b4):	mulx	-2*8(ap), r2, r3
mov	r0, 5*8(rp)
mulx	-1*8(ap), r0, r1
mov	r2, 6*8(rp)
lea	8*8(rp), rp
mulx	0*8(ap), r2, r3
jmp	L(b2)

L(f6):	mulx	-1*8(ap), r2, r3
mulx	0*8(ap), r0, r1
lea	5*8(ap), ap
lea	-3*8(rp), rp
lea	L(f7)(%rip), jmpreg
jmp	L(b6)

L(f5):	mulx	-1*8(ap), r0, r1
mulx	0*8(ap), r2, r3
lea	4*8(ap), ap
lea	-4*8(rp), rp
lea	L(f6)(%rip), jmpreg
jmp	L(b5)

L(f3):	mulx	-1*8(ap), r0, r1
mulx	0*8(ap), r2, r3
lea	2*8(ap), ap
lea	-6*8(rp), rp
lea	L(f4)(%rip), jmpreg
jmp	L(b3)

# All rows done: restore the callee-saved registers in reverse push order
# and return with the extra low limb in rax.
L(fin):	pop	%r14
pop	%r13
pop	%r12
pop	%rbp
pop	%rbx

ret
# End marker used to record the symbol size.
.flint_mpn_mulhigh_basecase_end:
SIZE(flint_mpn_mulhigh_basecase, .flint_mpn_mulhigh_basecase_end)
.cfi_endproc
```