Some perfect_square_p experiments
jason
jason at insomnia247.nl
Tue Mar 19 05:04:22 CET 2024
I was playing with gmp's mpn_perfect_square_p, and found an alternative
implementation. Unfortunately, it's basically the same speed and size
as the current one, so there's no reason to change. I'm just sending it
to the mailing list for archival purposes and in case anyone can find a way
to extract a net benefit from the ideas.
I did notice a few things in gen-psqr.c that could stand to be cleaned up.
I could send patches, but they're so simple and mechanical that doing
it yourself is probably easier than checking a patch:
1) The whole HAVE_CONST thing was deleted from gmp.h a long time ago.
It should be removed from here as well.
2) The various global variables and functions should be static.
3) The numerous instances of
printf("foo\n");
printf("bar %d\n", integer);
printf("baz %s\n", string);
should be merged to
printf("foo\n"
"bar %d\n"
"baz %s\n", integer, string);
Back to the number theory, mpn_perfect_square_p basically
does some modular congruence checks to see if the value
can be proved non-square modulo various small factors, then
goes ahead and computes a square root.
Phase 1 is to check the low bits modulo 256.
Phase 2 is to use mpn_mod_34lsub1 to quickly reduce
the operand modulo 2^48 - 1 or 2^24 - 1, and then
do some congruence checks on small factors of those
moduli. In particular, 2^24 - 1 is a multiple of 63, 65,
and 17. 2^48 - 1 has an additional factor of 2^24 + 1,
which has a small factor of 97.
The current code groups the prime factors into convenient
small factors, then computes the remainder modulo each
small factor with two multiplies, a mask, and a shift.
gen-psqr.c computes the necessary multipliers and bit
maps.
But if you don't re-group the prime factors, you have
nice convenient almost-powers-of-2, which can be reduced
with shift-and-add techniques like IP checksums.
My contributions are:
- A way to do tricks with the bit maps to handle a range
of slightly more than 64 remainders with a 64-bit map.
This both lets me handle mod-65 with a 64-bit map, and
allows imperfect modular reduction.
- A mixture of "fold by 2" (r = (r >> 24) + (r & 0xffffff)) and
"fold by 3" (r = ((r >> 48) + (r & 0xffffff)) + (r >> 24 & 0xffffff))
to reduce the value to the necessary range.
The latter is just as many operations as folding by 2 twice
(2 shifts, 2 masks, and 2 additions) but its depth is
only 3 operations rather than 2*2 = 4.
My alternate implementation is in the perfsqr_alt_test()
function. The rest of the test program just compares
it to the existing gmplib code.
=== Standalone test program ===
#include <stdio.h>
#include <stdbool.h>
/*** gmp-impl.h subset ***/
/* Minimal stand-ins for the gmp-impl.h definitions this test needs:
   64-bit limbs, no nail bits. */
typedef unsigned long long mp_limb_t;
#define GMP_LIMB_BITS 64
#define GMP_NAIL_BITS 0
#define GMP_NUMB_BITS (GMP_LIMB_BITS - GMP_NAIL_BITS)
#define CNST_LIMB(x) x##ull
/* Assertions are compiled out here.  There is deliberately no trailing
   semicolon: the caller supplies it, so "ASSERT (x);" is exactly one
   statement and stays valid in an unbraced if/else. */
#define ASSERT(expr) do {} while (0)
/*** mpn/perfsqr.h ***/
/* This file generated by gen-psqr.c - DO NOT EDIT. */
/* The tables below are only valid for 64-bit limbs with no nail bits; any
   other configuration hits the deliberate syntax error on the next line. */
#if GMP_LIMB_BITS != 64 || GMP_NAIL_BITS != 0
Error, error, this data is for 64 bit limb and 0 bit nail
#endif
/* Non-zero bit indicates a quadratic residue mod 0x100.
This test identifies 82.81% as non-squares (212/256). */
/* Disabled in this standalone test: the only user is the (also disabled)
   mpn_perfect_square_p copy further down.  The test functions start from
   the mod 2^48-1 residue tests instead. */
#if 0
static const mp_limb_t
sq_res_0x100[4] = {
CNST_LIMB(0x202021202030213),
CNST_LIMB(0x202021202020213),
CNST_LIMB(0x202021202030212),
CNST_LIMB(0x202021202020212),
};
#endif
/* 2^48-1 = 17 * 63 * 65 * 97 ... */
/* Width of the folded remainder r handed to the residue tests: folding a
   64-bit limb at bit 48 leaves at most 0xffffffffffff + 0xffff < 2^49. */
#define PERFSQR_MOD_BITS 49
/* This test identifies 97.81% as non-squares. */
/* Generated residue tests on the folded remainder r, using the divisors
   d = 91 (7*13), 85 (5*17), 9 and 97.  Each PERFSQR_MOD_* step executes
   "return 0" from the enclosing function when r is not a square mod d.
   The percentage above each step is the fraction of residue classes mod d
   that are non-squares. */
#define PERFSQR_MOD_TEST(r) \
do { \
/* 69.23% */ \
PERFSQR_MOD_2 (r, CNST_LIMB(91), CNST_LIMB(0xfd2fd2fd2fd3), \
CNST_LIMB(0x2191240), CNST_LIMB(0x8850a206953820e1)); \
\
/* 68.24% */ \
PERFSQR_MOD_2 (r, CNST_LIMB(85), CNST_LIMB(0xfcfcfcfcfcfd), \
CNST_LIMB(0x82158), CNST_LIMB(0x10b48c4b4206a105)); \
\
/* 55.56% */ \
PERFSQR_MOD_1 (r, CNST_LIMB( 9), CNST_LIMB(0xe38e38e38e39), \
CNST_LIMB(0x93)); \
\
/* 49.48% */ \
PERFSQR_MOD_2 (r, CNST_LIMB(97), CNST_LIMB(0xfd5c5f02a3a1), \
CNST_LIMB(0x1eb628b47), CNST_LIMB(0x6067981b8b451b5f)); \
} while (0)
/* This test identifies 97.81% as non-squares. */
/* Regrouped variant of PERFSQR_MOD_TEST using the divisors 63, 65, 97 and 17.
   Since 63*65*17 == 91*85*9 == 69615, it covers the same prime powers and so
   accepts exactly the same remainders.  This grouping matches the
   almost-powers-of-2 factors of 2^24-1 that perfsqr_alt_test reduces by. */
#define PERFSQR_MOD_TEST2(r) \
do { \
/* 74.60% */ \
PERFSQR_MOD_1 (r, CNST_LIMB(63), CNST_LIMB(0xfbefbefbefbf), \
CNST_LIMB(0x20d2210498082481)); \
\
/* 67.69% */ \
PERFSQR_MOD_2 (r, CNST_LIMB(65), CNST_LIMB(0xfc0fc0fc0fc1), \
CNST_LIMB(0x0), CNST_LIMB(0x9614a0231014a1a5)); \
\
/* 49.48% */ \
PERFSQR_MOD_2 (r, CNST_LIMB(97), CNST_LIMB(0xfd5c5f02a3a1), \
CNST_LIMB(0x1eb628b47), CNST_LIMB(0x6067981b8b451b5f)); \
\
/* 47.06% */ \
PERFSQR_MOD_1 (r, CNST_LIMB(17), CNST_LIMB(0xf0f0f0f0f0f1), \
CNST_LIMB(0x1a317)); \
} while (0)
/* Grand total sq_res_0x100 and PERFSQR_MOD_TEST, 99.62% non-squares. */
/* helper for tests/mpz/t-perfsqr.c */
/* (Not referenced anywhere in this standalone program.) */
#define PERFSQR_DIVISORS { 256, 63, 65, 97, 17, }
/*** mpn/perfsqr.c **/
/* mpn_perfect_square_p(u,usize) -- Return non-zero if U is a perfect square,
zero otherwise. */
/* change this to "#define TRACE(x) x" for diagnostics */
#define TRACE(x)
/* PERFSQR_MOD_* detects non-squares using residue tests.
A macro PERFSQR_MOD_TEST is set up by gen-psqr.c in perfsqr.h. It takes
{up,usize} modulo a selected modulus to get a remainder r. For 32-bit or
64-bit limbs this modulus will be 2^24-1 or 2^48-1 using PERFSQR_MOD_34,
or for other limb or nail sizes a PERFSQR_PP is chosen and PERFSQR_MOD_PP
used. PERFSQR_PP_NORM and PERFSQR_PP_INVERTED are pre-calculated in this
case too.
PERFSQR_MOD_TEST then makes various calls to PERFSQR_MOD_1 or
PERFSQR_MOD_2 with divisors d which are factors of the modulus, and table
data indicating residues and non-residues modulo those divisors. The
table data is in 1 or 2 limbs worth of bits respectively, per the size of
each d.
A "modexact" style remainder is taken to reduce r modulo d.
PERFSQR_MOD_IDX implements this, producing an index "idx" for use with
the table data. Notice there's just one multiplication by a constant
"inv", for each d.
The modexact doesn't produce a true r%d remainder, instead idx satisfies
"-(idx<<PERFSQR_MOD_BITS) == r mod d". Because d is odd, this factor
-2^PERFSQR_MOD_BITS is a one-to-one mapping between r and idx, and is
accounted for by having the table data suitably permuted.
The remainder r fits within PERFSQR_MOD_BITS which is less than a limb.
In fact the GMP_LIMB_BITS - PERFSQR_MOD_BITS spare bits are enough to fit
each divisor d meaning the modexact multiply can take place entirely
within one limb, giving the compiler the chance to optimize it, in a way
that say umul_ppmm would not give.
There's no need for the divisors d to be prime, in fact gen-psqr.c makes
a deliberate effort to combine factors so as to reduce the number of
separate tests done on r. But such combining is limited to d <=
2*GMP_LIMB_BITS so that the table data fits in at most 2 limbs.
Alternatives:
It'd be possible to use bigger divisors d, and more than 2 limbs of table
data, but this doesn't look like it would be of much help to the prime
factors in the usual moduli 2^24-1 or 2^48-1.
The moduli 2^24-1 or 2^48-1 are nothing particularly special, they're
just easy to calculate (see mpn_mod_34lsub1) and have a nice set of prime
factors. 2^32-1 and 2^64-1 would be equally easy to calculate, but have
fewer prime factors.
The nails case usually ends up using mpn_mod_1, which is a lot slower
than mpn_mod_34lsub1. Perhaps other such special moduli could be found
for the nails case. Two-term things like 2^30-2^15-1 might be
candidates. Or at worst some on-the-fly de-nailing would allow the plain
2^24-1 to be used. Currently nails are too preliminary to be worried
about.
*/
/* NOTE: in this standalone test, PERFSQR_MOD_TEST and PERFSQR_MOD_TEST2 take
   an already-folded remainder r rather than {up,usize}; the test functions
   fold r themselves. */
/* Mask of the low PERFSQR_MOD_BITS bits; the modexact arithmetic in
   PERFSQR_MOD_IDX is carried out modulo 2^PERFSQR_MOD_BITS. */
#define PERFSQR_MOD_MASK ((CNST_LIMB(1) << PERFSQR_MOD_BITS) - 1)
/* 3/4 of a limb: 48 for 64-bit limbs with no nails, so MOD34_MASK is 2^48-1
   and adding the bits above MOD34_BITS back onto the low part preserves the
   value modulo 2^48-1. */
#define MOD34_BITS (GMP_NUMB_BITS / 4 * 3)
#define MOD34_MASK ((CNST_LIMB(1) << MOD34_BITS) - 1)
/* Reduce {up,usize} modulo 2^MOD34_BITS-1 into r.  Not expanded in this
   standalone test (mpn_mod_34lsub1 is not defined in this file). */
#define PERFSQR_MOD_34(r, up, usize) \
do { \
(r) = mpn_mod_34lsub1 (up, usize); \
(r) = ((r) & MOD34_MASK) + ((r) >> MOD34_BITS); \
} while (0)
/* FIXME: The %= here isn't good, and might destroy any savings from keeping
the PERFSQR_MOD_IDX stuff within a limb (rather than needing umul_ppmm).
Maybe a new sort of mpn_preinv_mod_1 could accept an unnormalized divisor
and a shift count, like mpn_preinv_divrem_1. But mod_34lsub1 is our
normal case, so let's not worry too much about mod_1. */
/* PERFSQR_MOD_PP (r, up, usize): the non-mod_34lsub1 path.  It reduces
   {up,usize} modulo PERFSQR_PP into r, using the preinverted mod_1 (followed
   by %= PERFSQR_PP) for sizes below PREINV_MOD_1_TO_MOD_1_THRESHOLD and plain
   mpn_mod_1 otherwise.
   NOTE(review): never expanded in this standalone test.  BELOW_THRESHOLD, the
   threshold, mpn_preinv_mod_1, mpn_mod_1 and the PERFSQR_PP* constants are
   not defined in this file. */
#define PERFSQR_MOD_PP(r, up, usize) \
do { \
if (BELOW_THRESHOLD (usize, PREINV_MOD_1_TO_MOD_1_THRESHOLD)) \
{ \
(r) = mpn_preinv_mod_1 (up, usize, PERFSQR_PP_NORM, \
PERFSQR_PP_INVERTED); \
(r) %= PERFSQR_PP; \
} \
else \
{ \
(r) = mpn_mod_1 (up, usize, PERFSQR_PP); \
} \
} while (0)
/* PERFSQR_MOD_IDX (idx, r, d, inv): map a remainder r <= PERFSQR_MOD_MASK to
   a bitmap index idx in [0, d).
   - q = r * inv mod 2^PERFSQR_MOD_BITS, where inv is the inverse of the odd
     divisor d modulo 2^PERFSQR_MOD_BITS (second ASSERT).
   - idx is the part of q * d above bit PERFSQR_MOD_BITS.  Since
     q < 2^PERFSQR_MOD_BITS, idx < d.
   idx is not r % d but satisfies -(idx << PERFSQR_MOD_BITS) == r mod d.  The
   generated bitmaps are permuted to match.  The third ASSERT checks that
   q * d cannot overflow a limb.
   In this standalone test ASSERT expands to nothing, so these checks are
   documentation only. */
#define PERFSQR_MOD_IDX(idx, r, d, inv) \
do { \
mp_limb_t q; \
ASSERT ((r) <= PERFSQR_MOD_MASK); \
ASSERT ((((inv) * (d)) & PERFSQR_MOD_MASK) == 1); \
ASSERT (MP_LIMB_T_MAX / (d) >= PERFSQR_MOD_MASK); \
\
q = ((r) * (inv)) & PERFSQR_MOD_MASK; \
ASSERT (r == ((q * (d)) & PERFSQR_MOD_MASK)); \
(idx) = (q * (d)) >> PERFSQR_MOD_BITS; \
} while (0)
/* PERFSQR_MOD_1 (r, d, inv, mask): residue test for a divisor
   d <= GMP_LIMB_BITS.  It reduces r to a table index with PERFSQR_MOD_IDX and
   looks the index up in the one-limb bitmap MASK.  A clear bit means r is not
   a square mod d, and the macro executes "return 0" from the enclosing
   function.
   The TRACE arguments are cast to match their %u / %lu conversions: d is
   passed as a CNST_LIMB (unsigned long long) constant and r is an mp_limb_t,
   so without the casts enabling TRACE would be undefined behaviour. */
#define PERFSQR_MOD_1(r, d, inv, mask) \
do { \
    unsigned idx; \
    ASSERT ((d) <= GMP_LIMB_BITS); \
    PERFSQR_MOD_IDX(idx, r, d, inv); \
    TRACE (printf (" PERFSQR_MOD_1 d=%u r=%lu idx=%u\n", \
                   (unsigned) (d), (unsigned long) ((r) % (d)), idx)); \
    if ((((mask) >> idx) & 1) == 0) \
    { \
        TRACE (printf (" non-square\n")); \
        return 0; \
    } \
} while (0)
/* PERFSQR_MOD_2 (r, d, inv, mhi, mlo): residue test for a divisor
   d <= 2*GMP_LIMB_BITS whose bitmap spans two limbs.  Indices below
   GMP_LIMB_BITS are looked up in MLO and the rest in MHI.  A clear bit means
   r is not a square mod d, and the macro executes "return 0" from the
   enclosing function.
   The expression "(int) idx - GMP_LIMB_BITS < 0" lets the compiler use the
   sign bit from "idx-GMP_LIMB_BITS", which might help avoid a branch.
   The TRACE arguments are cast to match their %u / %lu conversions (d is an
   unsigned long long CNST_LIMB constant and r is an mp_limb_t). */
#define PERFSQR_MOD_2(r, d, inv, mhi, mlo) \
do { \
    mp_limb_t m; \
    unsigned idx; \
    ASSERT ((d) <= 2*GMP_LIMB_BITS); \
    \
    PERFSQR_MOD_IDX (idx, r, d, inv); \
    TRACE (printf (" PERFSQR_MOD_2 d=%u r=%lu idx=%u\n", \
                   (unsigned) (d), (unsigned long) ((r) % (d)), idx)); \
    m = ((int) idx - GMP_LIMB_BITS < 0 ? (mlo) : (mhi)); \
    idx %= GMP_LIMB_BITS; \
    if (((m >> idx) & 1) == 0) \
    { \
        TRACE (printf (" non-square\n")); \
        return 0; \
    } \
} while (0)
/* Reference filter: fold R modulo 2^48-1 into at most PERFSQR_MOD_BITS bits,
   then run the generated residue tests (divisors 91, 85, 9, 97).
   Returns false as soon as one test proves R is not a square, true
   otherwise. */
bool
perfsqr_mod_test(mp_limb_t r)
{
    const mp_limb_t folded = (r >> MOD34_BITS) + (r & MOD34_MASK);

    PERFSQR_MOD_TEST (folded);   /* may "return 0" on a non-residue */
    return true;
}
/* Same filter as perfsqr_mod_test, but using the regrouped divisors
   63, 65, 97 and 17 (PERFSQR_MOD_TEST2).  Returns false when R is shown
   not to be a square, true otherwise. */
bool
perfsqr_mod_test2(mp_limb_t r)
{
    /* The bits above MOD34_BITS are worth 1 each mod 2^MOD34_BITS-1, so
       adding them onto the low part keeps the residue. */
    mp_limb_t low49 = r & MOD34_MASK;
    low49 += r >> MOD34_BITS;

    PERFSQR_MOD_TEST2 (low49);   /* may "return 0" on a non-residue */
    return true;
}
#if 0
/* Reference copy of GMP's mpn_perfect_square_p, disabled here and kept only
   for context.  NOTE(review): as excerpted it would not compile if enabled:
   - it has no return type (implicit int is invalid since C99; upstream
     presumably declares it int);
   - it calls PERFSQR_MOD_TEST (up, usize), whereas the PERFSQR_MOD_TEST
     defined in this file takes a single folded remainder r;
   - mp_srcptr, mp_size_t, mp_ptr, TMP_*, gmp_printf, count_trailing_zeros
     and mpn_sqrtrem are not defined in this standalone file. */
mpn_perfect_square_p (mp_srcptr up, mp_size_t usize)
{
ASSERT (usize >= 1);
TRACE (gmp_printf ("mpn_perfect_square_p %Nd\n", up, usize));
/* The first test excludes 212/256 (82.8%) of the perfect square candidates
in O(1) time. */
{
unsigned idx = up[0] % 0x100;
/* Look up bit idx of the 256-bit residue map sq_res_0x100. */
if (((sq_res_0x100[idx / GMP_LIMB_BITS]
>> (idx % GMP_LIMB_BITS)) & 1) == 0)
return 0;
}
#if 0
/* Check that we have even multiplicity of 2, and then check that the rest is
a possible perfect square. Leave disabled until we can determine this
really is an improvement. If it is, it could completely replace the
simple probe above, since this should throw out more non-squares, but at
the expense of somewhat more cycles. */
{
mp_limb_t lo;
int cnt;
lo = up[0];
while (lo == 0)
up++, lo = up[0], usize--;
count_trailing_zeros (cnt, lo);
if ((cnt & 1) != 0)
return 0; /* return of not even multiplicity of 2 */
lo >>= cnt; /* shift down to align lowest non-zero bit */
/* An odd square is 1 mod 8, so bits 1 and 2 must both be clear. */
if ((lo & 6) != 0)
return 0;
}
#endif
/* The second test uses mpn_mod_34lsub1 or mpn_mod_1 to detect non-squares
according to their residues modulo small primes (or powers of
primes). See perfsqr.h. */
PERFSQR_MOD_TEST (up, usize);
/* For the third and last test, we finally compute the square root,
to make sure we've really got a perfect square. */
{
mp_ptr root_ptr;
int res;
TMP_DECL;
TMP_MARK;
/* A square root of an usize-limb number needs ceil(usize/2) limbs. */
root_ptr = TMP_ALLOC_LIMBS ((usize + 1) / 2);
/* Iff mpn_sqrtrem returns zero, the square is perfect. */
res = ! mpn_sqrtrem (root_ptr, NULL, up, usize);
TMP_FREE;
return res;
}
}
#endif
/* perfsqr_alt_test: shift-and-add alternative to perfsqr_mod_test.
   Returns false if R is provably not a square modulo 63, 65, 97 or 17, and
   true otherwise.  The divisors multiply to the same value as those of
   PERFSQR_MOD_TEST (63*65*17 == 91*85*9), so both accept the same R.  Only
   the mod-97 step still uses the multiply-based PERFSQR_MOD_2. */
bool
perfsqr_alt_test(mp_limb_t r)
{
    // Step 1: Mod-4095.  Fold r by 24 bits (2^24 == 1 mod 2^24-1) into s,
    // then fold s by 12 bits (2^12 == 1 mod 4095) into t.  Since 63 and 65
    // both divide 4095 = 2^12-1, t == r mod 63 and t == r mod 65.
    unsigned s = ((r >> 48) + (r & 0xffffff)) + (r >> 24 & 0xffffff);
    // s <= 0xffff + 0xffffff + 0xffffff = 0x200fffd
    unsigned t = (s >> 12) + (s & 0xfff);
    // t <= 0x200e + 0xfff = 0x300d = 030015
    // (s >> 12 == 0x200f forces s & 0xfff <= 0xffd, so this is the maximum.)

    // Step 2: Mod-63 with an offset of 13.  Since 64 == 1 mod 63, each
    // 6-bit fold preserves the value mod 63.
    unsigned u = (t >> 6) + (t & 63) + 13;
    // (t >> 6) + (t & 63) <= 0277 + 077 = 0376, because t >> 6 == 0300
    // forces t & 63 <= 015.  So u <= 0376 + 13 = 0413.
    u = (u >> 6) + (u & 63);
    // u <= 03 + 077 = 0102 = 66 (reached when u was 0377)
    // The bitmap of quadratic residues mod 63 looks like
    // Possible residues mod 63: 0 1 4 7 9 16 18 22 25 28 36 37 43 46 49 58 (16/63)
    // 0 1 2 3 4 5 6
    // 0 0 0 0 0 0 012
    // **..*..*.*......*.*...*..*..*.......**.....*..*..*........*....
    // But if we add an offset of +13 to the index, that's taking the last
    // 13 bits off the end and adding them to the beginning:
    // ........*....**..*..*.*......*.*...*..*..*.......**.....*..*..*
    // / / / / / / / / / / / / / / / /
    // (In hex, that's 0x49060248a0526100.)
    // At this point, consider two possible ways to extend this to higher
    // residues. One is the correct way, mod-63:
    // 0 1 2 3 4 5 6 7 7
    // 0 0 0 0 0 0 0 3 0 5
    // ........*....**..*..*.*......*.*...*..*..*.......**.....*..*..*........*....*
    // But suppose we add a bit 63 = 0 and do it mod-64 instead:
    // ........*....**..*..*.*......*.*...*..*..*.......**.....*..*..*.........*....*
    // Because our mask starts with several 0 bits, these two agree until
    // bit 71.  Since u <= 66 here, indexing the map mod 64 is exact.
    // Reducing mod-64 is free on most processors with bit shifts, so we can
    // save on the fold-and-add reductions.
    static const mp_limb_t mod63_mask = CNST_LIMB(0x49060248a0526100);
    if ((mod63_mask >> (u % 64) & 1) == 0)
        return false;

    // Step 3: Mod-65, with no offset (4096 == 1 and 64 == -1 mod 65).
    u = ((t >> 12) + (t & 63)) - (t >> 6 & 63);
    // u >= -63 (as a signed value; the unsigned variable wraps around)
    // u <= 02 + 077 = 65.  t >> 12 == 3 only for t in 0x3000..0x300d,
    // where t & 63 <= 13 and (t >> 6 & 63) == 0, giving u <= 16.
    // Conditional add of 65.  The shift amount must be at least 7, so that
    // a non-negative u <= 65 shifts to 0.  A wrapped negative u has every
    // bit from 6 upward set, so any shift from 7 to 25 (for a 32-bit
    // unsigned) leaves both bit 0 and bit 6 set and the mask yields 65.
    // We choose 8 because that might be optimized.
    u += (u >> 8) & 65;
    // Now 0 <= u <= 65 and u == r mod 65.
    // The bitmap of quadratic residues mod 65 looks like
    // Possible residues mod 65: 0 1 4 9 10 14 16 25 26 29 30 35 36 39 40 49 51 55 56 61 64 (21/65 = 32%)
    // 0 1 2 3 4 5 6
    // 0 0 0 0 0 0 01234
    // **..*....**...*.*........**..**....**..**........*.*...**....*..*
    // / / / / / / / / / / / / / / / /
    // (In hex, that's 0x1218a019866014613.  Bit 64 is dropped below because
    // index 64 wraps onto bit 0, which is also set.)
    // Here, it's important to note that 0, 1 and 64 are all quadratic
    // residues mod 65. Meaning that 65 and 66 are also quadratic residues.
    // If we take the remainder mod 64 and look it up in the bit map,
    // we'll find:
    // 64 -> 0 -> residue (correct)
    // 65 -> 1 -> residue (correct)
    // 66 -> 2 -> non-residue (WRONG!)
    // But still, this lets us handle remainders up to and including 65 with
    // a 64-bit map.
    static const mp_limb_t mod65_mask = CNST_LIMB(0x218a019866014613);
    if ((mod65_mask >> (u % 64) & 1) == 0)
        return false;

    // Step 4: Mod-97.  97 divides 2^24+1, a factor of 2^48-1 but not of
    // 2^24-1, so this step needs the 48-bit fold of the original r.
    r = (r & MOD34_MASK) + (r >> MOD34_BITS);
    // r <= 0xffff + 0xffffffffffff = 0x100000000FFFE;
    PERFSQR_MOD_2 (r, CNST_LIMB(97), CNST_LIMB(0xfd5c5f02a3a1),
                   CNST_LIMB(0x1eb628b47), CNST_LIMB(0x6067981b8b451b5f));
#if 0
    /* 47.06% */
    PERFSQR_MOD_1 (r, CNST_LIMB(17), CNST_LIMB(0xf0f0f0f0f0f1),
                   CNST_LIMB(0x1a317));
#else
    // Step 5: Mod-17.  17 divides 255, which divides 2^24-1, so
    // s == r mod 17.  Also 256 == 1 and 16 == -1 mod 17.
    // Remember, s <= 0x200fffd;
    t = ((s >> 16) + (s & 0xff)) + (s >> 8 & 0xff);
    // t <= 0x1ff + 0xff + 0xff = 0x3fd;
    t = ((t >> 8) + (t & 15)) - (t >> 4 & 15);
    // t >= -15
    // t <= 3 + 15 - 0 = 18 (reached only when t was 0x30f)
    // Possible residues mod 17: 0 1 2 4 8 9 13 15 16 (9/17)
    //
    // A 32-bit map indexed mod 32 cannot cover [-15, 18]: +18 (== 1, a
    // residue) and -14 (== 3, a non-residue) both land on bit 18.  (That
    // only happens for r == 18 mod 255, which is 3 mod 5 and so was already
    // rejected in step 3.  The wider map makes this step correct on its own.)
    // So index a 64-bit map mod 64 instead: bits 0..18 hold t = 0..18 and
    // bits 49..63 hold t = -15..-1.
    //   ***.*...**...*.****    t = 0..18
    //   *.*...**...*.**        t = -15..-1 (bits 49..63)
    // Hex 0xd18a00000007a317
    static const mp_limb_t mod17_mask = CNST_LIMB(0xd18a00000007a317);
    if ((mod17_mask >> (t % 64) & 1) == 0)
        return false;
#endif
    return true;
}
/* do_compare: call F and G on the same 10,000,000 pseudo-random limbs and
   stop at the first disagreement.  Returns true if they agree on every
   input; otherwise prints the offending input and both results and returns
   false. */
static bool do_compare(bool (*f)(mp_limb_t), bool (*g)(mp_limb_t))
{
    mp_limb_t seed = 0;
    for (int i = 0 ; i < 10000000; i++) {
        /* Deliberately simple LCG: x -> 5x + 1 mod 2^64.  It has full period
           because a - 1 is divisible by 4 and c is odd (Hull-Dobell). */
        seed = 5 * seed + 1;
        bool rf = f(seed);
        bool rg = g(seed);
        if (rf != rg) {
            /* Print both arguments in the same "0x..." form. */
            printf("Mismatch! f(%#llx) = %d != g(%#llx) = %d\n",
                   (unsigned long long)seed, rf,
                   (unsigned long long)seed, rg);
            return false;
        }
    }
    return true;
}
/* Time 1,000,000 calls of F on successive outputs of the x -> 5x + 1 LCG.
   Prints how many inputs F accepted and the elapsed TSC ticks, and returns
   the tick count truncated to unsigned.  x86-only (__builtin_ia32_rdtsc). */
static unsigned time_random(bool (*f)(mp_limb_t))
{
    mp_limb_t x = 0;
    unsigned hits = 0;
    const unsigned long long t0 = __builtin_ia32_rdtsc();

    for (int k = 1000000; k > 0; k--) {
        x = x * 5 + 1;          /* same LCG as do_compare */
        hits += f(x);
    }

    const unsigned long long dt = __builtin_ia32_rdtsc() - t0;
    printf("%u possible squares; %llu ticks\n", hits, dt);
    return (unsigned)dt;
}
static unsigned time_square(bool (*f)(mp_limb_t))
{
unsigned seed = 0;
unsigned count = 0;
unsigned long long ticks = __builtin_ia32_rdtsc();
for (int i = 0 ; i < 1000000; i++) {
seed = 5 * seed + 1; // The world's shittiest LCG
mp_limb_t sq = (mp_limb_t)seed * seed;
count += f(sq);
}
ticks = __builtin_ia32_rdtsc() - ticks;
printf("%u possible squares; %llu ticks\n", count, ticks);
return (unsigned)ticks;
}
/* Benchmark F: ten rounds of one random-input pass and one square-input
   pass, then print the summed tick counts of each kind. */
void
do_test(bool (*f)(mp_limb_t))
{
    unsigned random_ticks = 0;
    unsigned square_ticks = 0;
    int pass = 0;

    while (pass++ < 10) {
        printf("Random: ");
        random_ticks += time_random(f);
        printf("Square: ");
        square_ticks += time_square(f);
    }
    printf("Random total: %u\nSquare total: %u\n", random_ticks, square_ticks);
}
/* Check that both alternative filters agree with the generated reference,
   then benchmark all three filters, twice in the same order.
   Exit status 1 on the first disagreement, 0 otherwise. */
int
main(void)
{
    static const struct {
        const char *name;
        bool (*fn)(mp_limb_t);
    } runs[] = {
        { "mod_test",  perfsqr_mod_test },
        { "mod_test2", perfsqr_mod_test2 },
        { "alt_test",  perfsqr_alt_test },
    };

    if (!do_compare(perfsqr_mod_test, perfsqr_alt_test) ||
        !do_compare(perfsqr_mod_test, perfsqr_mod_test2))
        return 1;

    for (int pass = 0; pass < 2; pass++) {
        for (size_t i = 0; i < sizeof runs / sizeof runs[0]; i++) {
            puts(runs[i].name);
            do_test(runs[i].fn);
        }
    }
    return 0;
}
More information about the gmp-devel
mailing list