Proof of concept: multiplication algorithm for 8x8, 25% faster than n log n

Hans Petter Selasky hps at selasky.org
Thu Oct 27 19:56:20 CEST 2022


Hi,

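From what I know, variants of the fast Fourier transform give n log n
algorithms. In the case of 8x8 that gives a complexity of 3 * 8 = 24
(log2(8) = 3). The algorithm shown below needs only 18 non-complex
multiplications to complete an 8 by 8 multiplication, using a transform
similar to an FFT. Note that the result is the product reduced modulo
(x^4 - 1)^2 = x^8 - 2x^4 + 1, so it equals the ordinary polynomial product
only when that product has degree below 8. Currently looking into
scalability :-)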

I tried to search the literature, but to no avail :-(

cat << 'EOF' > test.cpp
/*-
  * Copyright (c) 2022 Hans Petter Selasky
  *
  * Redistribution and use in source and binary forms, with or without
  * modification, are permitted provided that the following conditions
  * are met:
  * 1. Redistributions of source code must retain the above copyright
  *    notice, this list of conditions and the following disclaimer.
  * 2. Redistributions in binary form must reproduce the above copyright
  *    notice, this list of conditions and the following disclaimer in the
  *    documentation and/or other materials provided with the distribution.
  *
  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 
PURPOSE
  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 
CONSEQUENTIAL
  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 
STRICT
  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN 
ANY WAY
  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
  * SUCH DAMAGE.
  */

#include <gmpxx.h>
#include <iostream>

typedef mpz_class d_t;

static void
convert_4to8_fwd(const d_t *src, d_t *var)
{
         var[0] = src[0] + src[1] + src[2] + src[3] ;
         var[1] = - 4 * src[1] - 8 * src[2] - 12 * src[3] ;
         var[2] = src[0] - src[1] + src[2] - src[3] ;
         var[3] = 6 * src[1] - 12 * src[2] + 18 * src[3] ;
         var[4] = src[0] - src[2] ;
         var[5] = 14 * src[2] ;
         var[6] = - 7 * src[1] + 21 * src[3] ;
         var[7] = - src[1] + src[3] ;
};

static void
convert_8to8_fwd(const d_t *src, d_t *var)
{
         var[0] = src[0] + src[1] + src[2] + src[3] + src[4] + src[5] + 
src[6] + src[7] ;
         var[1] = - 4 * src[1] - 8 * src[2] - 12 * src[3] - 16 * src[4] 
- 20 * src[5] - 24 * src[6] - 28 * src[7] ;
         var[2] = src[0] - src[1] + src[2] - src[3] + src[4] - src[5] + 
src[6] - src[7] ;
         var[3] = 6 * src[1] - 12 * src[2] + 18 * src[3] - 24 * src[4] + 
30 * src[5] - 36 * src[6] + 42 * src[7] ;
         var[4] = src[0] - src[2] + src[4] - src[6] ;
         var[5] = 14 * src[2] - 28 * src[4] + 42 * src[6] ;
         var[6] = - 7 * src[1] + 21 * src[3] - 35 * src[5] + 49 * src[7] ;
         var[7] = - src[1] + src[3] - src[5] + src[7] ;
};

static void
multiply_8by8(const d_t *sa, const d_t *sb, d_t *dst)
{
         dst[0] = + sa[0] * sb[0] ;
         dst[1] = + sa[0] * sb[1] + sa[1] * sb[0] ;
         dst[2] = + sa[2] * sb[2] ;
         dst[3] = + sa[2] * sb[3] + sa[3] * sb[2] ;
         dst[4] = + sa[4] * sb[4] - sa[7] * sb[7] ;
         dst[6] = + sa[4] * sb[6] + sa[6] * sb[4] - sa[7] * sb[5] - 
sa[5] * sb[7] ;
         dst[7] = + sa[4] * sb[7] + sa[7] * sb[4] ;
         dst[5] = + sa[4] * sb[5] + sa[6] * sb[7] + sa[7] * sb[6] + 
sa[5] * sb[4] ;
};

static void
convert_8to8_inv(const d_t *src, d_t *var)
{
         var[0] = (+ 336 * src[0] + 21 * src[1] + 336 * src[2] + 14 * 
src[3] + 672 * src[4] + 24 * src[5] ) / 1344;
         var[1] = (+ 420 * src[0] + 21 * src[1] - 420 * src[2] - 14 * 
src[3] + 24 * src[6] - 840 * src[7] ) / 1344;
         var[2] = (+ 504 * src[0] + 21 * src[1] + 504 * src[2] + 14 * 
src[3] - 1008 * src[4] - 24 * src[5] ) / 1344;
         var[3] = (+ 588 * src[0] + 21 * src[1] - 588 * src[2] - 14 * 
src[3] - 24 * src[6] + 1176 * src[7] ) / 1344;
         var[4] = (- 21 * src[1] - 14 * src[3] - 24 * src[5] ) / 1344;
         var[5] = (- 84 * src[0] - 21 * src[1] + 84 * src[2] + 14 * 
src[3] - 24 * src[6] + 168 * src[7] ) / 1344;
         var[6] = (- 168 * src[0] - 21 * src[1] - 168 * src[2] - 14 * 
src[3] + 336 * src[4] + 24 * src[5] ) / 1344;
         var[7] = (- 252 * src[0] - 21 * src[1] + 252 * src[2] + 14 * 
src[3] + 24 * src[6] - 504 * src[7] ) / 1344;
};

int main()
{
	for (int a = 0; a != 8; a++) {
		for (int b = 0; b != 8; b++) {
			d_t sa[8];
			d_t sb[8];
			d_t sa8[8];
			d_t sb8[8];
			d_t r8[8];
			d_t res[8];

			sa[a] = 1;
			sb[b] = 1;

			printf("%d x %d = ", a, b);

			convert_8to8_fwd(sa, sa8);
			convert_8to8_fwd(sb, sb8);

			multiply_8by8(sa8, sb8, r8);

			convert_8to8_inv(r8, res);

			for (int c = 0; c != 8; c++) {
				std::cout << res[c] << ", ";
			}
			std::cout << "\n";
		}
	}

	for (int a = 0; a != 8; a++) {
		for (int b = 0; b != 8; b++) {
			d_t sa[8];
			d_t sb[8];
			d_t sa8[8];
			d_t sb8[8];
			d_t r8[8];
			d_t res[8];

			sa[0] = (a & 1) ? 1 : 0;
			sa[1] = (a & 2) ? 1 : 0;
			sa[2] = (a & 4) ? 1 : 0;

			sb[0] = (b & 1) ? 1 : 0;
			sb[1] = (b & 2) ? 1 : 0;
			sb[2] = (b & 4) ? 1 : 0;

			printf("%d x %d = ", a, b);

			convert_8to8_fwd(sa, sa8);
			convert_8to8_fwd(sb, sb8);

			multiply_8by8(sa8, sb8, r8);

			convert_8to8_inv(r8, res);

			d_t sum;
			
			for (int c = 0; c != 8; c++) {
				std::cout << res[c] << ", ";
				sum += res[c] << c;
			}
			std::cout << " S=" << sum << "\n";
		}
	}
	return (0);
}
EOF

clang++ -I/usr/local/include -O2 test.cpp -L/usr/local/lib -lgmpxx -lgmp && ./a.out

--HPS


More information about the gmp-discuss mailing list