Proof of concept, 25% faster than n log n, multiplication algorithm for 8x8
Hans Petter Selasky
hps at selasky.org
Thu Oct 27 19:56:20 CEST 2022
Hi,
As far as I know, variants of the fast Fourier transform give n log n
multiplication algorithms. In the case of 8x8 that gives a complexity
of 3 * 8 = 24. The algorithm shown below needs only 18 non-complex
multiplications to complete an 8 by 8 multiplication, in a manner similar
to an FFT. I am currently looking into scalability :-)
I tried to search the literature, but to no avail :-(
cat << 'EOF' > test.cpp
/*-
* Copyright (c) 2022 Hans Petter Selasky
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
* are met:
* 1. Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
* OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
* HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
* LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
* OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
* SUCH DAMAGE.
*/
#include <gmpxx.h>
#include <iostream>
/* Arbitrary-precision integer type used for every coefficient. */
using d_t = mpz_class;
static void
convert_4to8_fwd(const d_t *src, d_t *var)
{
var[0] = src[0] + src[1] + src[2] + src[3] ;
var[1] = - 4 * src[1] - 8 * src[2] - 12 * src[3] ;
var[2] = src[0] - src[1] + src[2] - src[3] ;
var[3] = 6 * src[1] - 12 * src[2] + 18 * src[3] ;
var[4] = src[0] - src[2] ;
var[5] = 14 * src[2] ;
var[6] = - 7 * src[1] + 21 * src[3] ;
var[7] = - src[1] + src[3] ;
};
static void
convert_8to8_fwd(const d_t *src, d_t *var)
{
var[0] = src[0] + src[1] + src[2] + src[3] + src[4] + src[5] +
src[6] + src[7] ;
var[1] = - 4 * src[1] - 8 * src[2] - 12 * src[3] - 16 * src[4]
- 20 * src[5] - 24 * src[6] - 28 * src[7] ;
var[2] = src[0] - src[1] + src[2] - src[3] + src[4] - src[5] +
src[6] - src[7] ;
var[3] = 6 * src[1] - 12 * src[2] + 18 * src[3] - 24 * src[4] +
30 * src[5] - 36 * src[6] + 42 * src[7] ;
var[4] = src[0] - src[2] + src[4] - src[6] ;
var[5] = 14 * src[2] - 28 * src[4] + 42 * src[6] ;
var[6] = - 7 * src[1] + 21 * src[3] - 35 * src[5] + 49 * src[7] ;
var[7] = - src[1] + src[3] - src[5] + src[7] ;
};
/*
 * Multiply two transformed operands, using 18 d_t multiplications.
 *
 * The outputs fall into three independent groups: dst[0..1] only use
 * elements 0 and 1 of the inputs, dst[2..3] only elements 2 and 3,
 * and dst[4..7] only elements 4 to 7.
 *
 * sa, sb: transformed operands, 8 elements each, read only
 * dst:    receives the 8 element product; written in the order
 *         0, 1, 2, 3, 4, 6, 7, 5
 */
static void
multiply_8by8(const d_t *sa, const d_t *sb, d_t *dst)
{
	/* Group {0, 1}: 3 multiplications. */
	dst[0] = sa[0] * sb[0];
	dst[1] = sa[0] * sb[1] + sa[1] * sb[0];

	/* Group {2, 3}: 3 multiplications. */
	dst[2] = sa[2] * sb[2];
	dst[3] = sa[2] * sb[3] + sa[3] * sb[2];

	/* Group {4, 5, 6, 7}: 12 multiplications. */
	dst[4] = sa[4] * sb[4] - sa[7] * sb[7];
	dst[6] = (sa[4] * sb[6] + sa[6] * sb[4]) -
	    (sa[7] * sb[5] + sa[5] * sb[7]);
	dst[7] = sa[4] * sb[7] + sa[7] * sb[4];
	dst[5] = (sa[4] * sb[5] + sa[5] * sb[4]) +
	    (sa[6] * sb[7] + sa[7] * sb[6]);
}
static void
convert_8to8_inv(const d_t *src, d_t *var)
{
var[0] = (+ 336 * src[0] + 21 * src[1] + 336 * src[2] + 14 *
src[3] + 672 * src[4] + 24 * src[5] ) / 1344;
var[1] = (+ 420 * src[0] + 21 * src[1] - 420 * src[2] - 14 *
src[3] + 24 * src[6] - 840 * src[7] ) / 1344;
var[2] = (+ 504 * src[0] + 21 * src[1] + 504 * src[2] + 14 *
src[3] - 1008 * src[4] - 24 * src[5] ) / 1344;
var[3] = (+ 588 * src[0] + 21 * src[1] - 588 * src[2] - 14 *
src[3] - 24 * src[6] + 1176 * src[7] ) / 1344;
var[4] = (- 21 * src[1] - 14 * src[3] - 24 * src[5] ) / 1344;
var[5] = (- 84 * src[0] - 21 * src[1] + 84 * src[2] + 14 *
src[3] - 24 * src[6] + 168 * src[7] ) / 1344;
var[6] = (- 168 * src[0] - 21 * src[1] - 168 * src[2] - 14 *
src[3] + 336 * src[4] + 24 * src[5] ) / 1344;
var[7] = (- 252 * src[0] - 21 * src[1] + 252 * src[2] + 14 *
src[3] + 24 * src[6] - 504 * src[7] ) / 1344;
};
int main()
{
for (int a = 0; a != 8; a++) {
for (int b = 0; b != 8; b++) {
d_t sa[8];
d_t sb[8];
d_t sa8[8];
d_t sb8[8];
d_t r8[8];
d_t res[8];
sa[a] = 1;
sb[b] = 1;
printf("%d x %d = ", a, b);
convert_8to8_fwd(sa, sa8);
convert_8to8_fwd(sb, sb8);
multiply_8by8(sa8, sb8, r8);
convert_8to8_inv(r8, res);
for (int c = 0; c != 8; c++) {
std::cout << res[c] << ", ";
}
std::cout << "\n";
}
}
for (int a = 0; a != 8; a++) {
for (int b = 0; b != 8; b++) {
d_t sa[8];
d_t sb[8];
d_t sa8[8];
d_t sb8[8];
d_t r8[8];
d_t res[8];
sa[0] = (a & 1) ? 1 : 0;
sa[1] = (a & 2) ? 1 : 0;
sa[2] = (a & 4) ? 1 : 0;
sb[0] = (b & 1) ? 1 : 0;
sb[1] = (b & 2) ? 1 : 0;
sb[2] = (b & 4) ? 1 : 0;
printf("%d x %d = ", a, b);
convert_8to8_fwd(sa, sa8);
convert_8to8_fwd(sb, sb8);
multiply_8by8(sa8, sb8, r8);
convert_8to8_inv(r8, res);
d_t sum;
for (int c = 0; c != 8; c++) {
std::cout << res[c] << ", ";
sum += res[c] << c;
}
std::cout << " S=" << sum << "\n";
}
}
return (0);
}
EOF
clang++ -I/usr/local/include -L/usr/local/lib -O2 test.cpp -lgmpxx -lgmp && ./a.out
--HPS
More information about the gmp-discuss
mailing list