/* Implementation of QuickMul by Chen Li, 10/00.  $Id$
 *
 * Multiplies two big integers (old GNU Integer / IntRep representation, 16-bit
 * words, least significant first) with a number-theoretic transform over Z_M,
 * M = 2013265921 = 15 * 2^27 + 1:
 *   1. split both operands into L-bit digits            (convert_base8)
 *   2. transform, multiply point-wise, transform back    (NEW_DFT, comp_mult)
 *   3. propagate carries and repack into 16-bit words    (post_proc, store_back)
 */
#include "Integer.h"
#include "qmul.h"

#include <assert.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>

// The transforms below assume PRIM_ELE has multiplicative order 2^27 modulo M
// (it is squared 27 - s times to obtain a root of order 2^s), and use
// REV_PRIM_ELE for the reverse transform (presumably PRIM_ELE^{-1} mod M).
// REV_2 is the inverse of 2 modulo M (2 * REV_2 == M + 1).
static const int PRIM_ELE = 440564289;
static const int REV_PRIM_ELE = 1713844692;
static const int M = 2013265921;
static const int REV_2 = 1006632961;

// lg of the largest transform length supported by PRIM_ELE.
static const int LG_MAX_DFT = 27;

// Smallest lg such that 2^lg >= len (i.e. lg(len) for a power of two len).
static int ceil_lg(int len)
{
    int lg = 0;
    while ((1 << lg) < len)
        ++lg;
    return lg;
}

// Returns a root of unity of order 2^lg_order modulo M (the reverse-transform
// root when inverse is true) by squaring the order-2^27 root 27 - lg_order times.
static long long root_of_unity(int lg_order, bool inverse)
{
    long long w = inverse ? REV_PRIM_ELE : PRIM_ELE;
    for (int i = lg_order; i < LG_MAX_DFT; i++)
        w = (w * w) % M;
    return w;
}

// Multiplies every entry of dst[0..len) by len^{-1} mod M (len a power of two).
// This is the final normalisation step of the reverse transform.
static void scale_by_inverse_len(int *dst, int len)
{
    if (len <= 1)
        return;
    long long inv_len = REV_2;                  // 2^{-1}
    for (int p = 2; p < len; p <<= 1)
        inv_len = (inv_len * REV_2) % M;        // one more factor 2^{-1}
    assert((len * inv_len) % M == 1);
    for (int i = 0; i < len; i++)
        dst[i] = (int) ((inv_len * dst[i]) % M);
}

// Digit widths L (bits) supported by convert_base8 / post_proc / store_back;
// each divides the 16-bit width of an IntRep word.
static bool supported_digit_width(int L)
{
    return L == 2 || L == 4 || L == 8;
}

// Returns a newly allocated IntRep holding x * y.
// NOTE(review): z is handed to store_back, which does not reuse or release it;
// confirm the ownership convention expected by callers.
IntRep* qmul(const IntRep* x, const IntRep* y, IntRep* z)
{
    // Bit length of the longer operand.
    int chunk_size = 8 * sizeof(x->s[0]);
    int max_len = (x->len < y->len) ? y->len : x->len;
    int N = max_len * chunk_size;

    // Decide the parameters L (digit width) and K (number of digits).
    /* THIS IS THE TABLE for CHOOSING L, given N
       \begin{center}
       \begin{tabular}{r | r | r | r }
       \hline
       $L$ & $\lg(K)$ & $K$        & $N = L * K$ \\
       \hline\hline
        1  & 26 & 67,108,864 & 67,108,864 \\  % no practical use (N(1) = N(2))
        2  & 25 & 33,554,432 & 67,108,864 \\
        3  & 23 &  8,388,608 & 25,165,824 \\
        4  & 21 &  2,097,152 &  8,388,608 \\
        5  & 19 &    524,288 &  2,621,440 \\
        6  & 17 &    131,072 &    786,432 \\
        7  & 15 &     32,768 &    229,376 \\
        8  & 13 &      8,192 &     65,536 \\
        9  & 11 &      2,048 &     18,432 \\
       10  &  9 &        512 &      5,120 \\
       11  &  7 &        128 &      1,408 \\
       12  &  5 &         32 &        384 \\
       13  &  3 &          8 &        104 \\
       14  &  1 &          2 &         28 \\  % not practical
       \hline
       \end{tabular}
       \end{center}
    */
    // We only choose L to be 8, 4 or 2.
    int L;
    if (N < 65536)
        L = 8;
    else if (N < 8388608)
        L = 4;
    else if (N < 67108864)  // must not be written "67,108,864": that is a comma expression
        L = 2;
    else {
        printf("N is too big! \n");
        exit(1);
    }

    // K = smallest power of two >= the number of L-bit chunks. Integer
    // arithmetic keeps exact powers of two exact and handles N == 0.
    int chunks = (N + L - 1) / L;
    int K = 1;
    int lg_K = 0;
    while (K < chunks) {
        K <<= 1;
        ++lg_K;
    }
    // Each product coefficient is a sum of at most K terms below 2^{2L}; it
    // must stay below M so the modular result equals the integer coefficient.
    assert((1LL << (2 * L + 1 + lg_K)) <= M);
    assert(K * L >= N);
#ifdef DEBUG
    cout << "L = " << L << ", K = " << K << endl;
#endif

    // Zero-padding to 2K digits makes the cyclic convolution a linear one.
    int K2 = 2 * K;

    // Convert both operands to base 2^L digit vectors.
    int *a = new int[K2];
    convert_base8(a, K2, x->s, x->len, L);
    int *b = new int[K2];
    convert_base8(b, K2, y->s, y->len, L);

    int *A = new int[K2];
    int *B = new int[K2];
    int *R = new int[K2];
    int *r = new int[K2];
    int len_prod = K2 + (32 + L) / L;   // extra words absorb the carries
    int *prod = new int[len_prod];

    NEW_DFT(a, K2, A, false);             // forward transform of a
    NEW_DFT(b, K2, B, false);             // forward transform of b
    comp_mult(A, B, R, K2);               // component-wise multiplication
    NEW_DFT(R, K2, r, true);              // reverse transform
    post_proc(r, K2, prod, len_prod, L);  // carries: back to base 2^L digits
    IntRep *p = store_back(prod, len_prod, z, L);  // repack into an IntRep

    // store_back copied the digits, so every scratch buffer can be released.
    delete[] a;
    delete[] b;
    delete[] A;
    delete[] B;
    delete[] R;
    delete[] r;
    delete[] prod;

    p->sgn = (x->sgn == y->sgn) ? 1 : 0;
    Icheck(p);
    return p;
}

#ifdef REC_FFT
// Recursive transform. Computes dst = DFT(src) over Z_M for a power-of-two
// len <= 2^27, or the reverse DFT (including the 1/len scaling) when reverse.
void NEW_DFT(int *src, int len, int* dst, bool reverse)
{
    long long init_pe = root_of_unity(ceil_lg(len), reverse);
#ifdef DEBUG
    cout << "Primitive element is " << init_pe << endl;
    for (int i = 0; i < len; i++) {
        cout << "input[" << i << "] = " << src[i] << endl;
    }
#endif
    g_src = src;  // read by DFT_REC's DEBUG output
    DFT_REC(src, len, 1, (int) init_pe, dst);
    if (reverse)
        scale_by_inverse_len(dst, len);
}
#else
// Iterative (bottom-up, radix-2) transform, more efficient. Same contract as
// the recursive version: dst = DFT(src), or the scaled reverse DFT.
void NEW_DFT(int *src, int len, int* dst, bool reverse)
{
    bit_reverse_copy(src, dst, len);
    const int lg_len = ceil_lg(len);

    // Stages from 1 (the bottom level) to lg(len) (the top level).
    for (int s = 1; s <= lg_len; s++) {
        const int m = 1 << s;          // size of the sub-vectors in this stage
        const int m_half = m / 2;
        const long long w_m = root_of_unity(s, reverse);
        int *dst_half = dst + m_half;

        long long omega = 1;
        for (int j = 0; j < m_half; j++) {
            for (int k = j; k < len; k += m) {
                long long v = (dst_half[k] * omega) % M;
                long long u = dst[k] % M;
                dst[k] = (int) ((u + v) % M);
                dst_half[k] = (int) ((u - v + M) % M);
            }
            omega = (omega * w_m) % M;
        }
    }
    if (reverse)
        scale_by_inverse_len(dst, len);
}
#endif

// Copies a[0..len) into A so that A[rev(k)] = a[k], where rev reverses the
// lg(len) low bits of k. len must be a power of two.
void bit_reverse_copy(int *a, int *A, int len)
{
    const int lg_len = ceil_lg(len);
    for (int k = 0; k < len; k++) {
        int rev = 0;
        for (int j = 0; j < lg_len; j++) {
            if (k & (1 << j))
                rev |= 1 << (lg_len - 1 - j);
        }
        A[rev] = a[k];
    }
} // bit_reverse_copy

// Recursive radix-2 transform of the len elements src[0], src[inc],
// src[2*inc], ... into dst[0..len), with pe a root of unity of order len.
// dst is used as scratch for the two half-size sub-transforms.
void DFT_REC(int *src, int len, int inc, int pe, int *dst)
{
    if (len == 1) {  // the base case
        dst[0] = src[0];
#ifdef DEBUG
        cout << "The leaf is " << src[0] << ", " << src - g_src << endl;
#endif
        return;
    }

    const int half = len / 2;
    const int pe2 = (int) (((long long) pe * pe) % M);
    // Even-indexed inputs go to the first half of dst, odd-indexed to the second.
    DFT_REC(src, half, inc * 2, pe2, dst);
    DFT_REC(src + inc, half, inc * 2, pe2, dst + half);
#ifdef DEBUG
    for (int i = 0; i < half; i++) {
        cout << "y0[" << i << "] = " << dst[i]
             << ", y1[" << i << "] = " << dst[i + half] << endl;
    }
    cout << endl;
#endif

    // Butterflies: combine the two half-size transforms in place.
    long long omega = 1;
    for (int k = 0; k < half; k++) {
        long long d1 = dst[k];
        long long d2 = (dst[k + half] * omega) % M;
        dst[k] = (int) ((d1 + d2) % M);
        dst[k + half] = (int) ((d1 - d2 + M) % M);
        omega = (omega * pe) % M;
    }
} // DFT_REC

// Component-wise modular multiplication: R[i] = A[i] * B[i] mod M.
void comp_mult(int *A, int *B, int *R, int len)
{
    for (int i = 0; i < len; i++)
        R[i] = (int) (((long long) A[i] * B[i]) % M);
} // comp_mult

// Adds the digit a to z[index] and propagates carries upwards so that every
// touched position ends below 2^L. The caller must leave room for the carries.
void add_and_carry(int a, int index, int *z, int L)
{
    const int m = 1 << L;
    z[index] += a;
    while (z[index] >= m) {  // more than L bits: carry into the next digit
        z[index + 1] += z[index] >> L;
        z[index] %= m;
        ++index;
    }
} // add_and_carry

// Converts the convolution coefficients r[0..lr) (each < M, weight 2^{L*i})
// into normalised base-2^L digits in prod[0..lp).
void post_proc(int *r, int lr, int *prod, int lp, int L)
{
    assert(supported_digit_width(L));
    for (int i = 0; i < lp; i++)
        prod[i] = 0;
#ifdef DEBUG
    Integer sum = 0;
    cout << "len = " << lr << endl;
    for (int i = lr - 1; i >= 0; i--) {
        sum = sum << L;
        sum = sum + r[i];
    }
    cout << "post sum = " << sum << endl;
#endif

    // Discard the leading zeroes in r.
    int len = lr;
    while (len > 0 && r[len - 1] == 0)
        --len;

    // Each coefficient fits in 32 bits; split it into L-bit digits and add
    // digit j at position i + j.
    const int digits_per_word = (32 + L - 1) / L;
    const unsigned int mask = (1u << L) - 1;
    assert(len + digits_per_word <= lp);
    for (int i = 0; i < len; i++) {
        const unsigned int word = (unsigned int) r[i];
        for (int j = 0; j < digits_per_word; j++)
            add_and_carry((int) ((word >> (j * L)) & mask), i + j, prod, L);
    }
} // post_proc

// Packs the L-bit digits r[0..len) (least significant first, each < 2^L) into
// a newly allocated IntRep of 16-bit words (old GNU Integer representation).
// NOTE(review): z is neither reused nor released here; confirm with callers.
IntRep* store_back(int *r, int len, IntRep *z, int L)
{
    (void) z;
    assert(supported_digit_width(L));
    while (len > 0 && r[len - 1] == 0)
        --len;
#ifdef DEBUG
    Integer sum = 0;
    cout << "len = " << len << endl;
    for (int i = len - 1; i >= 0; i--) {
        sum = sum << L;
        sum = sum + r[i];
    }
    cout << "sum = " << sum << ", L = " << L << endl;
#endif

    const int digits_per_word = (int) (8 * sizeof(unsigned short int)) / L;
    int len2 = (len + digits_per_word - 1) / digits_per_word;
    IntRep *p = Inew(len2);
    for (int k = 0; k < p->sz; k++)
        p->s[k] = 0;
    p->len = len2;
    p->sgn = 1;
    for (int i = 0; i < len2; i++) {
        unsigned int word = 0;
        for (int j = 0; j < digits_per_word; j++) {
            int idx = digits_per_word * i + j;
            if (idx < len)  // digits past the trimmed length are zero
                word += (unsigned int) r[idx] << (j * L);
        }
        p->s[i] = (unsigned short int) word;
#ifdef DEBUG
        cout << "p->s[" << i << "] = " << p->s[i] << endl;
#endif
    }
    return p;
} // store_back

// Splits the ls 16-bit words of src (least significant first) into L-bit
// digits, least significant first, stored in dst[0..ld); the rest of dst is
// zero. L must be 2, 4 or 8.
void convert_base8(int *dst, int ld, const unsigned short int *src, int ls, int L)
{
    for (int i = 0; i < ld; i++)
        dst[i] = 0;
    if (!supported_digit_width(L)) {
        cerr << "Error in " << __FILE__ << " at line: " << __LINE__
             << ": the base must be 2, 4 or 8" << endl;
        return;
    }

    const int digits_per_word = (int) (8 * sizeof(unsigned short int)) / L;
    const unsigned int mask = (1u << L) - 1;
    assert(ld >= digits_per_word * ls);
    for (int i = 0; i < ls; i++) {
        for (int j = 0; j < digits_per_word; j++)
            dst[digits_per_word * i + j] = (int) ((src[i] >> (j * L)) & mask);
    }
#ifdef DEBUG
    Integer sum = 0;
    cout << "len = " << ld << endl;
    for (int i = ld - 1; i >= 0; i--) {
        sum = sum << L;
        sum = sum + dst[i];
    }
    cout << "readin sum = " << sum << endl;
#endif
} // convert_base8

// CORE_randomize_rep creates a random non-negative big integer with at most
// bit_length bits (uses rand(); seed with srand() for reproducibility).
IntRep *CORE_randomize_rep(int bit_length)
{
    const int short_bits = 8 * sizeof(unsigned short int);
    long int num_shorts = (bit_length + short_bits - 1) / short_bits;
    if (num_shorts < 0)
        num_shorts = 0;
    IntRep *r = Inew(num_shorts);
    r->sgn = 1;
    r->len = num_shorts;
    for (int k = 0; k < r->len; k++)
        r->s[k] = (unsigned short int) (rand() % (USHRT_MAX + 1));

    // Keep only the q low bits of the top word when bit_length is not a
    // multiple of the word size.
    int q = bit_length % short_bits;
    if (q > 0 && r->len > 0)
        r->s[r->len - 1] &= (unsigned short int) ((1u << q) - 1);
    Icheck(r);
    return r;
} // CORE_randomize_rep