/* NN.C - natural numbers routines
*/

/* Copyright (C) RSA Laboratories, a division of RSA Data Security,
Inc., created 1991. All rights reserved.
*/

#include "global.h"
#include "rsaref.h"
#include "r_random.h"
#include "nn.h"

/* Computes the double-digit product a = b * c, where b and c are digits.

Lengths: a[2]. a[0] receives the low digit of the product and a[1] the
high digit.
*/
static inline void NN_DigitMult(NN_DIGIT a[2], NN_DIGIT b, NN_DIGIT c)
{
#ifdef _WIN32
	/* Multiply as unsigned 64-bit values. A signed 64-bit product of two
	   large digits can overflow, which is undefined behavior; the unsigned
	   product is well defined. This branch treats a digit as 32 bits wide
	   (mask 0xFFFFFFFF, shift by 32). */
	unsigned __int64 t = (unsigned __int64)b * (unsigned __int64)c;
	a[0] = (NN_DIGIT)(t & 0xFFFFFFFF);
	a[1] = (NN_DIGIT)(t >> 32);
#else
	NN_DIGIT t, u;
	NN_HALF_DIGIT bHigh, bLow, cHigh, cLow;

	/* Split both operands into half digits. */
	bHigh = (NN_HALF_DIGIT)HIGH_HALF (b);
	bLow = (NN_HALF_DIGIT)LOW_HALF (b);
	cHigh = (NN_HALF_DIGIT)HIGH_HALF (c);
	cLow = (NN_HALF_DIGIT)LOW_HALF (c);

	/* Four partial products: low*low, the two cross terms, high*high. */
	a[0] = (NN_DIGIT)bLow * (NN_DIGIT)cLow;
	t = (NN_DIGIT)bLow * (NN_DIGIT)cHigh;
	u = (NN_DIGIT)bHigh * (NN_DIGIT)cLow;
	a[1] = (NN_DIGIT)bHigh * (NN_DIGIT)cHigh;

	/* Sum the cross terms; a wrap-around is a carry into the upper half
	   of the high digit. */
	if ((t += u) < u)
		a[1] += TO_HIGH_HALF (1);
	u = TO_HIGH_HALF (t);

	/* Add the low half of the cross sum into the low digit, propagating
	   the carry, then add its high half into the high digit. */
	if ((a[0] += u) < u)
		a[1]++;
	a[1] += HIGH_HALF (t);
#endif
}

/* Sets *a = b / c, where the result and c are single digits and b is a
double-digit value (b[0] is the low digit, b[1] the high digit).

Lengths: b[2].
Assumes b[1] < c and HIGH_HALF (c) > 0. For efficiency, c should be
normalized.

The _WIN32 branch performs one 64-bit division and saturates the quotient
at 0xFFFFFFFF; it divides by c directly, so c must be nonzero. The portable
branch builds the quotient one half digit at a time: each half is first
underestimated and then corrected upward by repeated subtraction.
*/
static inline void NN_DigitDiv(NN_DIGIT *a, NN_DIGIT b[2], NN_DIGIT c)
{
#ifdef _WIN32
	/* Assemble the 64-bit dividend from the two digits of b. */
	unsigned __int64 t=(((unsigned __int64)b[1])<<32) + (unsigned __int64)b[0];
	t/=(unsigned __int64) c;
	/* Clamp a quotient that does not fit in one digit. */
	if (t > 0xFFFFFFFF) t=0xFFFFFFFF;
	*a=(NN_DIGIT)t;
#else
	NN_DIGIT t[2], u, v;
	NN_HALF_DIGIT aHigh, aLow, cHigh, cLow;

	cHigh = (NN_HALF_DIGIT)HIGH_HALF (c);
	cLow = (NN_HALF_DIGIT)LOW_HALF (c);

	/* t is a working copy of the dividend; it is reduced in place and
	   ends up holding the running remainder. */
	t[0] = b[0];
	t[1] = b[1];

	/* Underestimate high half of quotient and subtract.
	   Dividing by cHigh + 1 can never overestimate; the cHigh == MAX case
	   is handled separately because cHigh + 1 would not fit a half digit.
	*/
	if (cHigh == MAX_NN_HALF_DIGIT)
		aHigh = (NN_HALF_DIGIT)HIGH_HALF (t[1]);
	else
		aHigh = (NN_HALF_DIGIT)(t[1] / (cHigh + 1));
	u = (NN_DIGIT)aHigh * (NN_DIGIT)cLow;
	v = (NN_DIGIT)aHigh * (NN_DIGIT)cHigh;
	/* Subtract aHigh * c, shifted up by half a digit, from t; the
	   comparison detects a borrow out of t[0]. */
	if ((t[0] -= TO_HIGH_HALF (u)) > (MAX_NN_DIGIT - TO_HIGH_HALF (u)))
		t[1]--;
	t[1] -= HIGH_HALF (u);
	t[1] -= v;

	/* Correct estimate: while the remainder is still at least c shifted
	   up by half a digit, subtract it once more and bump aHigh.
	*/
	while ((t[1] > cHigh) ||
		((t[1] == cHigh) && (t[0] >= TO_HIGH_HALF (cLow)))) {
			if ((t[0] -= TO_HIGH_HALF (cLow)) > MAX_NN_DIGIT - TO_HIGH_HALF (cLow))
				t[1]--;
			t[1] -= cHigh;
			aHigh++;
		};

		/* Underestimate low half of quotient and subtract.
		*/
		if (cHigh == MAX_NN_HALF_DIGIT)
			aLow = (NN_HALF_DIGIT)LOW_HALF (t[1]);
		else
			aLow =
			(NN_HALF_DIGIT)((TO_HIGH_HALF (t[1]) + HIGH_HALF (t[0])) / (cHigh + 1));
		u = (NN_DIGIT)aLow * (NN_DIGIT)cLow;
		v = (NN_DIGIT)aLow * (NN_DIGIT)cHigh;
		/* Subtract aLow * c from t, tracking borrows out of t[0]. */
		if ((t[0] -= u) > (MAX_NN_DIGIT - u))
			t[1]--;
		if ((t[0] -= TO_HIGH_HALF (v)) > (MAX_NN_DIGIT - TO_HIGH_HALF (v)))
			t[1]--;
		t[1] -= HIGH_HALF (v);

		/* Correct estimate: subtract c until the remainder is below c.
		*/
		while ((t[1] > 0) || ((t[1] == 0) && t[0] >= c)) {
			if ((t[0] -= c) > (MAX_NN_DIGIT - c))
				t[1]--;
			aLow++;
		};

		/* Combine the two quotient halves into one digit. */
		*a = TO_HIGH_HALF (aHigh) + aLow;
#endif
}

/* Forward declarations of file-local helpers. They are defined near the
   end of this file but are called by routines that appear before those
   definitions. */
static inline NN_DIGIT NN_AddDigitMult(NN_DIGIT *, NN_DIGIT *, NN_DIGIT, NN_DIGIT *, unsigned int);
static inline NN_DIGIT NN_SubDigitMult(NN_DIGIT *, NN_DIGIT *, NN_DIGIT, NN_DIGIT *, unsigned int);
static inline unsigned int NN_DigitBits(NN_DIGIT);

/* Converts the big-endian byte string b into the digit array a (least
significant digit first).

Lengths: a[digits], b[len].
Bytes are consumed from the end of b (least significant) toward the front
and packed into successive digits. If b holds more bytes than fit in
digits digits, the most significant bytes are dropped; if it holds fewer,
the remaining high digits of a are set to zero.
*/
void NN_Decode(NN_DIGIT *a, unsigned int digits, unsigned char *b, unsigned int len)
{
	unsigned int digitIdx = 0;
	unsigned int bitPos;
	int bytePos = len - 1;
	NN_DIGIT word;

	while (digitIdx < digits && bytePos >= 0) {
		/* Assemble one digit from up to NN_DIGIT_BITS / 8 bytes. */
		word = 0;
		bitPos = 0;
		while (bytePos >= 0 && bitPos < NN_DIGIT_BITS) {
			word |= ((NN_DIGIT)b[bytePos]) << bitPos;
			bytePos--;
			bitPos += 8;
		}
		a[digitIdx] = word;
		digitIdx++;
	}

	/* Zero-fill whatever digits the input did not reach. */
	while (digitIdx < digits) {
		a[digitIdx] = 0;
		digitIdx++;
	}
}

/* Converts the digit array b (least significant digit first) into the
big-endian byte string a.

Lengths: a[len], b[digits].
Bytes are written from the end of a (least significant) toward the front.
If b needs more than len bytes, the most significant part is dropped; any
leading bytes of a that are not reached are set to zero.
*/
void NN_Encode(unsigned char *a, unsigned int len, NN_DIGIT *b, unsigned int digits)
{
	unsigned int digitIdx = 0;
	unsigned int bitPos;
	int bytePos = len - 1;
	NN_DIGIT word;

	while (digitIdx < digits && bytePos >= 0) {
		word = b[digitIdx];
		bitPos = 0;
		while (bytePos >= 0 && bitPos < NN_DIGIT_BITS) {
			/* The explicit 0xff mask was marked "HACK md5chap" by the
			   original author; it keeps exactly one byte of the digit. */
			a[bytePos] = (unsigned char)((word >> bitPos) & 0xff);
			bytePos--;
			bitPos += 8;
		}
		digitIdx++;
	}

	/* Pad the unused most significant bytes with zeros. */
	while (bytePos >= 0) {
		a[bytePos] = 0;
		bytePos--;
	}
}

/* Copies the digits-digit number b into a.

Lengths: a[digits], b[digits].
On _WIN32 the copy is a single memcpy; elsewhere it is a digit-by-digit
loop from the least significant digit upward.
*/
void NN_Assign (NN_DIGIT *a, NN_DIGIT *b, unsigned int digits)
{
#ifndef _WIN32
	unsigned int k = 0;

	while (k < digits) {
		a[k] = b[k];
		k++;
	}
#else
	memcpy((void *)a, (void *)b, digits * sizeof(NN_DIGIT));
#endif
}

/* Sets every one of the digits digits of a to zero.

Lengths: a[digits].
On _WIN32 this is a single memset; elsewhere it is a digit-by-digit loop.
*/
void NN_AssignZero(NN_DIGIT *a, unsigned int digits)
{
#ifdef _WIN32
	memset((void *)a, 0, sizeof(NN_DIGIT) * digits);
#else
	unsigned int k = 0;

	while (k < digits) {
		a[k] = 0;
		k++;
	}
#endif
}

/* Sets a to the power of two 2^b.

Lengths: a[digits].
Requires b < digits * NN_DIGIT_BITS. If b is out of range, a is simply
left as zero.
*/
void NN_Assign2Exp(NN_DIGIT *a, unsigned int b, unsigned int digits)
{
	NN_AssignZero (a, digits);

	/* Only set the bit when it actually lies inside the array. */
	if (b < digits * NN_DIGIT_BITS) {
		unsigned int digitIdx = b / NN_DIGIT_BITS;
		unsigned int bitIdx = b % NN_DIGIT_BITS;

		a[digitIdx] = (NN_DIGIT)1 << bitIdx;
	}
}

/* Adds two digits-digit numbers: a = b + c. Returns the carry out of the
most significant digit (0 or 1).

Lengths: a[digits], b[digits], c[digits].
*/
NN_DIGIT NN_Add(NN_DIGIT *a, NN_DIGIT *b, NN_DIGIT *c, unsigned int digits)
{
	NN_DIGIT carry = 0;
	unsigned int k;

	for (k = 0; k < digits; k++) {
		NN_DIGIT sum = b[k] + carry;

		if (sum < carry) {
			/* b[k] + carry wrapped around to zero, so the digit is just
			   c[k] and the carry stays set. */
			sum = c[k];
		}
		else {
			sum += c[k];
			carry = (sum < c[k]);
		}
		a[k] = sum;
	}

	return (carry);
}

/* Subtracts two digits-digit numbers: a = b - c. Returns the borrow out of
the most significant digit (0 or 1).

Lengths: a[digits], b[digits], c[digits].
*/
NN_DIGIT NN_Sub(NN_DIGIT *a, NN_DIGIT *b, NN_DIGIT *c, unsigned int digits)
{
	NN_DIGIT borrow = 0;
	unsigned int k;

	for (k = 0; k < digits; k++) {
		NN_DIGIT diff = b[k] - borrow;

		if (diff > (MAX_NN_DIGIT - borrow)) {
			/* b[k] - borrow wrapped around to MAX_NN_DIGIT, so the digit
			   is MAX_NN_DIGIT - c[k] and the borrow stays set. */
			diff = MAX_NN_DIGIT - c[k];
		}
		else {
			diff -= c[k];
			borrow = (diff > (MAX_NN_DIGIT - c[k]));
		}
		a[k] = diff;
	}

	return (borrow);
}

/* Multiplies two digits-digit numbers: a = b * c.

Lengths: a[2*digits], b[digits], c[digits].
Assumes digits < MAX_NN_DIGITS.
The product is accumulated in a local buffer and copied to a at the end,
so a may be the same array as b or c.
*/
void NN_Mult(NN_DIGIT *a, NN_DIGIT *b, NN_DIGIT *c, unsigned int digits)
{
	NN_DIGIT product[2*MAX_NN_DIGITS];
	unsigned int bLen, cLen, k;

	NN_AssignZero(product, 2 * digits);

	/* Only the significant digits of each operand take part. */
	bLen = NN_Digits(b, digits);
	cLen = NN_Digits(c, digits);

	/* Schoolbook multiplication: add b[k] * c at digit offset k and store
	   the carry in the digit just above the partial product. */
	for (k = 0; k < bLen; k++) {
		product[k + cLen] += NN_AddDigitMult(&product[k], &product[k], b[k], c, cLen);
	}

	NN_Assign (a, product, 2 * digits);

	/* Zeroize potentially sensitive information. */
	R_memset((POINTER)product, 0, sizeof (product));
}

/* Computes a = b * 2^c (i.e., shifts left c bits), returning the bits
shifted out of the most significant digit as the carry.

Lengths: a[digits], b[digits].
Requires c < NN_DIGIT_BITS. If c is out of range, a is left untouched and
0 is returned.
*/
NN_DIGIT NN_LShift(NN_DIGIT *a, NN_DIGIT *b,unsigned int c, unsigned int digits)
{
	NN_DIGIT carry = 0;
	unsigned int i, t;

	/* Reject out-of-range shift counts before computing NN_DIGIT_BITS - c:
	   that difference is unsigned, so for c > NN_DIGIT_BITS it wraps to a
	   large value instead of going negative, and shifting a digit by its
	   full width or more is undefined. */
	if (c >= NN_DIGIT_BITS) return (0);

	if (c == 0) {
		/* Nothing to shift: plain copy (a right shift by the full digit
		   width in the general path would be undefined). */
		for (i = 0; i < digits; i++) a[i] = b[i];
		return (0);
	}

	t = NN_DIGIT_BITS - c;
	for (i = 0; i < digits; i++) {
		NN_DIGIT bi = b[i];

		a[i] = (bi << c) | carry;
		/* The top c bits of this digit become the low bits of the next. */
		carry = (bi >> t);
	}
	return (carry);
}

/* Computes a = b div 2^c (i.e., shifts right c bits), returning the bits
shifted out of the least significant digit as the carry (left-aligned in
the returned digit).

Lengths: a[digits], b[digits].
Requires c < NN_DIGIT_BITS. If c is out of range, a is left untouched and
0 is returned.
*/
NN_DIGIT NN_RShift(NN_DIGIT *a, NN_DIGIT *b, unsigned int c, unsigned int digits)
{
	NN_DIGIT carry = 0;
	unsigned int i, t;

	/* Reject out-of-range shift counts before computing NN_DIGIT_BITS - c:
	   that difference is unsigned, so for c > NN_DIGIT_BITS it wraps to a
	   large value instead of going negative, and shifting a digit by its
	   full width or more is undefined. */
	if (c >= NN_DIGIT_BITS) return (0);

	if (c == 0) {
		/* Nothing to shift: plain copy. */
		for (i = 0; i < digits; i++) a[i] = b[i];
		return (0);
	}

	t = NN_DIGIT_BITS - c;
	/* Walk from the most significant digit down. Indexing (rather than
	   pre-advancing the pointers by digits - 1) keeps digits == 0 safe. */
	i = digits;
	while (i--) {
		NN_DIGIT bi = b[i];

		a[i] = (bi >> c) | carry;
		/* The low c bits of this digit become the top bits of the next
		   lower digit. */
		carry = (bi << t);
	}
	return (carry);
}

/* Computes a = c div d and b = c mod d.

Lengths: a[cDigits], b[dDigits], c[cDigits], d[dDigits].
Assumes d > 0, cDigits < 2 * MAX_NN_DIGITS,
dDigits < MAX_NN_DIGITS.

If d has no significant digits (d == 0) the function returns immediately
without writing a or b. The operands are copied into local buffers before
a and b are written.
*/
void NN_Div(NN_DIGIT *a, NN_DIGIT *b, NN_DIGIT *c, unsigned int cDigits, NN_DIGIT *d, unsigned int dDigits)
{
	NN_DIGIT ai, cc[2*MAX_NN_DIGITS+1], dd[MAX_NN_DIGITS], t;
	int i;
	unsigned int ddDigits, shift;

	/* Work with the significant length of the divisor only. */
	ddDigits = NN_Digits (d, dDigits);
	if (ddDigits == 0) return;

	//Normalize operands.
	// shift is the number of leading zero bits in the top divisor digit;
	// both operands are shifted left by it so that digit's top bit is set.
	shift = NN_DIGIT_BITS - NN_DigitBits(d[ddDigits-1]);
	// Clear the low ddDigits digits of cc before the shifted dividend is
	// stored (presumably so the remainder read back at the end is fully
	// defined when cDigits < ddDigits).
	NN_AssignZero (cc, ddDigits);
	// The bits shifted out of the dividend go into one extra top digit.
	cc[cDigits] = NN_LShift (cc, c, shift, cDigits);
	NN_LShift (dd, d, shift, ddDigits);
	// t is the most significant digit of the normalized divisor.
	t = dd[ddDigits-1];

	NN_AssignZero(a, cDigits);

	// Produce one quotient digit per iteration, most significant first.
	for (i = cDigits-ddDigits; i >= 0; i--) {
		// Underestimate quotient digit and subtract.
		// Dividing by t + 1 cannot overestimate; t == MAX_NN_DIGIT is
		// handled separately because t + 1 would wrap to zero.
		if (t == MAX_NN_DIGIT) ai = cc[i+ddDigits];
		else NN_DigitDiv(&ai, &cc[i+ddDigits-1], t + 1);
		cc[i+ddDigits] -= NN_SubDigitMult(&cc[i], &cc[i], ai, dd, ddDigits);

		//Correct estimate.
		// Keep subtracting the divisor while the partial remainder is
		// still at least as large as it.
		while (cc[i+ddDigits] || (NN_Cmp (&cc[i], dd, ddDigits) >= 0)) {
			ai++;
			cc[i+ddDigits] -= NN_Sub (&cc[i], &cc[i], dd, ddDigits);
		};
		a[i] = ai;
	};

	//Restore result.
	// The remainder is what is left in the low ddDigits digits of cc,
	// shifted back right to undo the normalization.
	NN_AssignZero (b, dDigits);
	NN_RShift(b, cc, shift, ddDigits);

	//Zeroize potentially sensitive information.
	R_memset ((POINTER)cc, 0, sizeof (cc));
	R_memset ((POINTER)dd, 0, sizeof (dd));
}

/* Reduces b modulo c: a = b mod c.

Lengths: a[cDigits], b[bDigits], c[cDigits].
Assumes c > 0, bDigits < 2 * MAX_NN_DIGITS, cDigits < MAX_NN_DIGITS.
The quotient produced by NN_Div is discarded. Its scratch buffer has
static storage, so it is shared by every call to this function.
*/
void NN_Mod(NN_DIGIT *a, NN_DIGIT *b, unsigned int bDigits, NN_DIGIT *c, unsigned int cDigits)
{
	static NN_DIGIT quotient[2 * MAX_NN_DIGITS];

	/* NN_Div writes the remainder straight into a. */
	NN_Div(quotient, a, b, bDigits, c, cDigits);

	/* Zeroize potentially sensitive information. */
	R_memset ((POINTER)quotient, 0, sizeof (quotient));
}

/* Modular multiplication: a = (b * c) mod d.

Lengths: a[digits], b[digits], c[digits], d[digits].
Assumes d > 0, digits < MAX_NN_DIGITS.
The full double-length product is formed in a local buffer first, so a may
be the same array as b or c.
*/
void NN_ModMult(NN_DIGIT *a, NN_DIGIT *b, NN_DIGIT *c, NN_DIGIT *d, unsigned int digits)
{
	NN_DIGIT product[2*MAX_NN_DIGITS];

	/* Double-length product, then reduce it modulo d. */
	NN_Mult(product, b, c, digits);
	NN_Mod(a, product, 2 * digits, d, digits);

	/* Zeroize potentially sensitive information. */
	R_memset ((POINTER)product, 0, sizeof (product));
}

/* Computes a = b^c mod d.

Lengths: a[dDigits], b[dDigits], c[cDigits], d[dDigits].
Assumes d > 0, cDigits > 0, dDigits < MAX_NN_DIGITS.

The exponent is processed two bits at a time using the precomputed powers
b, b^2 and b^3 (mod d).

If rand is non-NULL the computation is blinded: a random factor "blind" is
drawn from rand and every intermediate value is kept multiplied by it; the
factor is removed with its modular inverse at the end. If no usable factor
can be obtained (blind is not invertible modulo d, e.g. it reduced to
zero), the function falls back to the unblinded computation so that the
result is still correct. If rand is NULL no blinding is applied.
*/
void NN_ModExp(NN_DIGIT *a, NN_DIGIT *b, NN_DIGIT *c, unsigned int cDigits, NN_DIGIT *d, unsigned int dDigits, R_RANDOM_STRUCT *rand)
{
	NN_DIGIT bPower[3][MAX_NN_DIGITS], ci, t[MAX_NN_DIGITS];
	NN_DIGIT blind[MAX_NN_DIGITS];
	NN_DIGIT blindinv[MAX_NN_DIGITS];
	NN_DIGIT blindinv4[MAX_NN_DIGITS];
	NN_DIGIT check[MAX_NN_DIGITS];
	int i;
	int useBlind = 0;
	unsigned int ciBits, j, s;

	if (rand) {
		/* Draw the blinding factor. The buffer is cleared first and the
		   factor is validated below, so whatever the generator returns or
		   leaves in the buffer, an unusable factor is never used. */
		NN_AssignZero(blind, MAX_NN_DIGITS);
		(void)R_GenerateBytes((unsigned char *)blind, sizeof(blind), rand);
		NN_Mod(blind, blind, MAX_NN_DIGITS, d, dDigits);
		NN_ModInv(blindinv, blind, d, dDigits);

		/* Blinding is only sound when blind * blindinv == 1 (mod d).
		   With a zero or otherwise non-invertible factor every
		   intermediate value would collapse and the result would be
		   wrong, so verify the inverse before relying on it. */
		NN_ModMult(check, blind, blindinv, d, dDigits);
		useBlind = (NN_Digits(check, dDigits) == 1 && check[0] == 1);
	};

	if (useBlind) {
		/* blindinv4 = blindinv^4: cancels the blind^4 produced by the two
		   squarings performed for each pair of exponent bits. */
		NN_ModMult(blindinv4,blindinv,blindinv,d,dDigits);
		NN_ModMult(blindinv4,blindinv4,blindinv4,d,dDigits);

		NN_Assign(bPower[0], blind, dDigits);
		NN_Assign(t, blind, dDigits);
	}
	else {
		NN_ASSIGN_DIGIT(bPower[0], 1, dDigits);
		NN_ASSIGN_DIGIT(t, 1, dDigits);
	};

	//Store b, b^2 mod d, and b^3 mod d (each times blind when blinding).
	NN_ModMult(bPower[0], bPower[0], b, d, dDigits);
	NN_ModMult(bPower[1], bPower[0], b, d, dDigits);
	NN_ModMult(bPower[2], bPower[1], b, d, dDigits);

	cDigits = NN_Digits (c, cDigits);
	for (i = cDigits - 1; i >= 0; i--) {
		ci = c[i];
		ciBits = NN_DIGIT_BITS;

		//Scan past leading zero bits of most significant digit.
		if (i == (int)(cDigits - 1)) {
			while (! DIGIT_2MSB (ci)) {
				ci <<= 2;
				ciBits -= 2;
			};
		};

		for (j = 0; j < ciBits; j += 2, ci <<= 2) {
			//Compute t = t^4 * b^s mod d, where s = two MSB's of ci.
			NN_ModMult(t, t, t, d, dDigits);
			NN_ModMult(t, t, t, d, dDigits);
			if ( (s = DIGIT_2MSB(ci)) != 0 ) {
				NN_ModMult(t, t, bPower[s-1], d, dDigits);
			}
			else {
				//Keep the same number of blind factors as the s != 0 case.
				if (useBlind) {
					NN_ModMult(t, t, blind, d, dDigits);
				};
			};
			//t now carries blind^5; strip four factors so one remains.
			if (useBlind) {
				NN_ModMult(t, t, blindinv4, d, dDigits);
			};
		};
	};

	//Remove the single remaining blinding factor.
	if (useBlind) {
		NN_ModMult(t, t, blindinv, d, dDigits);
	};

	NN_Assign (a, t, dDigits);

	/* Zeroize potentially sensitive information.
	*/
	R_memset((POINTER)bPower, 0, sizeof (bPower));
	R_memset((POINTER)t, 0, sizeof (t));
	R_memset((POINTER)blind, 0, sizeof (blind));
	R_memset((POINTER)blindinv, 0, sizeof (blindinv));
	R_memset((POINTER)blindinv4, 0, sizeof (blindinv4));
	R_memset((POINTER)check, 0, sizeof (check));
}

/* Computes the modular inverse a = 1/b mod c, assuming the inverse exists.

Lengths: a[digits], b[digits], c[digits].
Assumes gcd (b, c) = 1, digits < MAX_NN_DIGITS.

Uses the extended Euclidean algorithm. Only magnitudes of the coefficient
are stored; its sign alternates on every step and is tracked separately so
that no negative numbers are needed.
*/
void NN_ModInv(NN_DIGIT *a, NN_DIGIT *b, NN_DIGIT *c, unsigned int digits)
{
	NN_DIGIT quot[MAX_NN_DIGITS], coefNext[MAX_NN_DIGITS], remNext[MAX_NN_DIGITS],
		coefPrev[MAX_NN_DIGITS], remPrev[MAX_NN_DIGITS], coefCur[MAX_NN_DIGITS],
		remCur[MAX_NN_DIGITS], prod[2*MAX_NN_DIGITS];
	int negative = 0;	/* nonzero when the coefficient in coefPrev is negative */

	/* Start with coefficients (1, 0) and remainders (b, c). */
	NN_ASSIGN_DIGIT (coefPrev, 1, digits);
	NN_AssignZero (coefCur, digits);
	NN_Assign (remPrev, b, digits);
	NN_Assign (remCur, c, digits);

	for (;;) {
		if (NN_Zero (remCur, digits))
			break;

		/* One Euclidean step: divide, then form the next coefficient
		   magnitude coefPrev + quot * coefCur. */
		NN_Div (quot, remNext, remPrev, digits, remCur, digits);
		NN_Mult (prod, quot, coefCur, digits);
		NN_Add (coefNext, coefPrev, prod, digits);

		/* Rotate the remainder pair and the coefficient pair. */
		NN_Assign (remPrev, remCur, digits);
		NN_Assign (remCur, remNext, digits);
		NN_Assign (coefPrev, coefCur, digits);
		NN_Assign (coefCur, coefNext, digits);
		negative = !negative;
	}

	/* A negative coefficient is mapped into range as c - |coefficient|. */
	if (negative)
		NN_Sub (a, c, coefPrev, digits);
	else
		NN_Assign (a, coefPrev, digits);

	/* Zeroize potentially sensitive information. */
	R_memset((POINTER)quot, 0, sizeof (quot));
	R_memset((POINTER)coefNext, 0, sizeof (coefNext));
	R_memset((POINTER)remNext, 0, sizeof (remNext));
	R_memset((POINTER)coefPrev, 0, sizeof (coefPrev));
	R_memset((POINTER)remPrev, 0, sizeof (remPrev));
	R_memset((POINTER)coefCur, 0, sizeof (coefCur));
	R_memset((POINTER)remCur, 0, sizeof (remCur));
	R_memset((POINTER)prod, 0, sizeof (prod));
}

/* Computes the greatest common divisor a = gcd(b, c) with Euclid's
algorithm.

Lengths: a[digits], b[digits], c[digits].
Assumes b > c, digits < MAX_NN_DIGITS.
*/
void NN_Gcd(NN_DIGIT *a, NN_DIGIT *b, NN_DIGIT *c, unsigned int digits)
{
	NN_DIGIT rem[MAX_NN_DIGITS], larger[MAX_NN_DIGITS], smaller[MAX_NN_DIGITS];

	NN_Assign (larger, b, digits);
	NN_Assign (smaller, c, digits);

	/* Replace (larger, smaller) by (smaller, larger mod smaller) until
	   the second value reaches zero. */
	for (;;) {
		if (NN_Zero(smaller, digits))
			break;
		NN_Mod(rem, larger, digits, smaller, digits);
		NN_Assign (larger, smaller, digits);
		NN_Assign (smaller, rem, digits);
	}

	NN_Assign (a, larger, digits);

	/* Zeroize potentially sensitive information. */
	R_memset((POINTER)rem, 0, sizeof (rem));
	R_memset((POINTER)larger, 0, sizeof (larger));
	R_memset((POINTER)smaller, 0, sizeof (smaller));
}

/* Compares two digits-digit numbers and returns the sign of a - b:
1 if a > b, -1 if a < b, 0 if they are equal.

Lengths: a[digits], b[digits].
*/
int NN_Cmp(NN_DIGIT *a, NN_DIGIT *b, unsigned int digits)
{
	int idx = digits - 1;

	/* The first differing digit, scanning from the most significant end,
	   decides the result. */
	while (idx >= 0) {
		if (a[idx] != b[idx])
			return (a[idx] > b[idx]) ? 1 : -1;
		idx--;
	}
	return (0);
}

/* Tests whether a is zero: returns 1 if every digit of a is zero (also
when digits is 0), and 0 as soon as a nonzero digit is found.

Lengths: a[digits].
*/
int NN_Zero(NN_DIGIT *a, unsigned int digits)
{
	unsigned int k;

	for (k = 0; k < digits; k++) {
		if (a[k] != 0)
			return 0;
	}
	return (1);
}

/* Returns the number of significant bits in a (0 when a is zero).

Lengths: a[digits].
*/
unsigned int NN_Bits(NN_DIGIT *a, unsigned int digits)
{
	unsigned int sigDigits = NN_Digits(a, digits);

	if (sigDigits == 0)
		return (0);

	/* Full digits below the top one, plus the bits used in the top digit. */
	return ((sigDigits - 1) * NN_DIGIT_BITS + NN_DigitBits (a[sigDigits-1]));
}

/* Returns the number of significant digits in a, i.e. the index of the
highest nonzero digit plus one (0 when a is zero).

Lengths: a[digits].
*/
unsigned int NN_Digits(NN_DIGIT *a, unsigned int digits)
{
	int top = digits;

	/* Scan down from the most significant end past zero digits. */
	for (top--; top >= 0; top--) {
		if (a[top])
			break;
	}
	return (top + 1);
}

/* Multiply-accumulate by a single digit: a = b + c*d, where c is a digit.
Returns the carry out of the top digit.

Lengths: a[digits], b[digits], d[digits].
When c is zero the function returns 0 at once and does not write a.
*/
static inline NN_DIGIT NN_AddDigitMult(NN_DIGIT *a, NN_DIGIT *b, NN_DIGIT c, NN_DIGIT *d, unsigned int digits)
{
	NN_DIGIT carry = 0, prod[2];
	unsigned int k;

	if (c == 0)
		return (0);

	for (k = 0; k < digits; k++) {
		/* prod[1]:prod[0] = c * d[k] */
		NN_DigitMult (prod, c, d[k]);

		/* Add the incoming carry, noting any wrap-around. */
		a[k] = b[k] + carry;
		carry = (a[k] < carry) ? 1 : 0;

		/* Add the low product digit, again noting a wrap-around. */
		a[k] += prod[0];
		if (a[k] < prod[0])
			carry++;

		/* The high product digit goes into the next position. */
		carry += prod[1];
	}
	return (carry);
}

/* Multiply-subtract by a single digit: a = b - c*d, where c is a digit.
Returns the borrow out of the top digit.

Lengths: a[digits], b[digits], d[digits].
When c is zero the function returns 0 at once and does not write a.
*/
static inline NN_DIGIT NN_SubDigitMult(NN_DIGIT *a, NN_DIGIT *b, NN_DIGIT c, NN_DIGIT *d, unsigned int digits)
{
	NN_DIGIT borrow = 0, prod[2];
	unsigned int k;

	if (c == 0)
		return (0);

	for (k = 0; k < digits; k++) {
		/* prod[1]:prod[0] = c * d[k] */
		NN_DigitMult (prod, c, d[k]);

		/* Subtract the incoming borrow, noting any wrap-around. */
		a[k] = b[k] - borrow;
		borrow = (a[k] > (MAX_NN_DIGIT - borrow)) ? 1 : 0;

		/* Subtract the low product digit, again noting a wrap-around. */
		a[k] -= prod[0];
		if (a[k] > (MAX_NN_DIGIT - prod[0]))
			borrow++;

		/* The high product digit is borrowed from the next position. */
		borrow += prod[1];
	}

	return (borrow);
}

// Returns the number of significant bits in the digit a, i.e. the position
// of its highest set bit plus one (0 when a is zero).
static unsigned int NN_DigitBits(NN_DIGIT a)
{
	unsigned int bits;

	// Shift right until nothing is left, counting the shifts.
	for (bits = 0; a != 0 && bits < NN_DIGIT_BITS; bits++)
		a >>= 1;

	return (bits);
}
