for review: big ints

Bert Thompson aet at
Fri Apr 3 17:53:38 AEST 1998

G'day people,

Could someone please review this addition to the library.

P.S. The answer to your first question is "because it was there".

Estimated hours taken: 7

Implementation of an arbitrary precision integer type and
operations on it.


% Copyright (C) 1997-1998 The University of Melbourne.
% This file may only be copied under the terms of the GNU General
% Public License - see the file COPYING in the Mercury distribution.
% file: integer.m
% authors: aet Apr 1998. 
% Implements an arbitrary precision integer type and basic 
% operations on it.
% (For comments on possible improvements, see the end of this file.)

:- module integer.

:- interface.

:- import_module string.

	% An arbitrary precision integer. The representation is
	% abstract to clients; see the implementation section.
:- type integer.

	% Ordering and equality tests. Each succeeds iff the
	% relation holds between the first and second argument.
:- pred integer:'<'(integer, integer).
:- mode integer:'<'(in, in) is semidet.

:- pred integer:'>'(integer, integer).
:- mode integer:'>'(in, in) is semidet.

:- pred integer:'=<'(integer, integer).
:- mode integer:'=<'(in, in) is semidet.

:- pred integer:'>='(integer, integer).
:- mode integer:'>='(in, in) is semidet.

:- pred integer:'='(integer, integer).
:- mode integer:'='(in, in) is semidet.

	% Convert a machine int to an arbitrary precision integer.
:- func integer(int) = integer.

	% Render an integer in decimal, with a leading "-" if it
	% is negative.
:- func integer:to_string(integer) = string.

	% Parse a decimal string with an optional leading '-'.
	% Calls error/1 if the string cannot be parsed.
:- func integer:from_string(string) = integer.

	% Unary plus and unary minus.
:- func integer:'+'(integer) = integer.

:- func integer:'-'(integer) = integer.

	% Addition, subtraction and multiplication.
:- func integer:'+'(integer, integer) = integer.

:- func integer:'-'(integer, integer) = integer.

:- func integer:'*'(integer, integer) = integer.

	% Quotient and remainder. These presumably delegate to
	% big_div_rem (their bodies are not visible here), which
	% calls error/1 on a zero divisor and gives the remainder
	% the sign of the dividend.
	% NOTE(review): confirm the rounding direction of '/' for
	% negative operands.
:- func integer:'/'(integer, integer) = integer.

:- func integer:'rem'(integer, integer) = integer.

	% Absolute value.
:- func integer:abs(integer) = integer.

	% pow(A, N, P): P is A raised to the power N.
	% Calls error/1 if N is negative.
:- pred integer:pow(integer, integer, integer).
:- mode integer:pow(in, in, out) is det.

	% Conversion to float is not provided yet.
% :- func integer:float(integer) = float.

:- implementation.

:- import_module require, list, char, std_util, int.

:- type sign == int.	% -1, 0, +1
:- type digit == int.	% base 10000 digit

	% The outcome of comparing two integers or two magnitudes.
:- type comparison
	--->	lessthan
	;	equal
	;	greaterthan.

	% An integer is its sign and the base 10000 digits of its
	% magnitude. Zero is presumably i(0, []), since
	% signum(0) = 0 and pos_int_to_digits(0) = [].
	% Note: the list of digits is stored in reverse order.
	% That is, little end first.
:- type integer
	--->	i(sign, list(digit)).

	% We choose base = 10000 so that digit-level intermediate
	% values fit comfortably in an int. The largest ones
	% visible in this file -- D*X in mul_by_digit_2,
	% R*base + H in div_by_digit_rev_2 and U0*B + U1 in
	% div_rev_2 -- are below base^2 = 100,000,000 whenever
	% their inputs are digits, and adding a carry keeps them
	% below base^2 + base = 100,010,000 < 2^31 - 1
	% (2,147,483,647). So any int of at least 32 bits is wide
	% enough.
:- func base = int.
base = 10000.

% The ordering predicates classify the pair with a single
% call to big_cmp and then inspect the verdict.
integer:'<'(X1, X2) :-
	Verdict = big_cmp(X1, X2),
	Verdict = lessthan.

integer:'>'(X1, X2) :-
	Verdict = big_cmp(X1, X2),
	Verdict = greaterthan.

	% X1 =< X2 holds unless X1 is strictly greater.
integer:'=<'(X1, X2) :-
	not (big_cmp(X1, X2) = greaterthan).

	% X1 >= X2 holds unless X1 is strictly less.
integer:'>='(X1, X2) :-
	not (big_cmp(X1, X2) = lessthan).

integer:'='(X1, X2) :-
	Verdict = big_cmp(X1, X2),
	Verdict = equal.

% The integer constant one.
:- func one = integer.
one = One :-
	One = integer(1).

	% The integer constant zero.
:- func zero = integer.
zero = Zero :-
	Zero = integer(0).

% Arithmetic operators.
% NOTE(review): in this copy every clause below ends in "="
% with no right-hand side, so none of them parses. They
% presumably delegate to the helpers that follow (big_neg,
% big_plus, big_mul, big_div, big_rem); the missing bodies
% need to be restored from the original source.
integer:'+'(X1) =

integer:'-'(N) =

integer:'+'(X1,X2) =

integer:'-'(X1, X2) = 

integer:'*'(X1,X2) =

integer:'/'(X1, X2) =

integer:'rem'(X1, X2) =

integer:abs(i(_Sgn,Digits)) =

% Negate an integer.
% NOTE(review): the clause body is missing in this copy;
% presumably it flips the sign S and keeps the digits Ds.
:- func big_neg(integer) = integer.
big_neg(i(S,Ds)) =

% Multiply two integers: the sign of the product is the
% product of the signs, and the magnitude is the product of
% the digit lists. The sign product is written int:'*' so it
% is not confused with integer:'*'.
:- func big_mul(integer, integer) = integer.
big_mul(i(SignA, DigitsA), i(SignB, DigitsB)) =
	i(int:'*'(SignA, SignB), pos_mul(DigitsA, DigitsB)).

% Quotient of X1 divided by X2.
% NOTE(review): the body after "Q :-" is missing in this
% copy; presumably it calls big_div_rem(X1, X2, Q, _).
:- func big_div(integer, integer) = integer.
big_div(X1,X2) =
	Q :-

	% Remainder of X1 divided by X2.
	% NOTE(review): the body after "R :-" is missing in this
	% copy; presumably it calls big_div_rem(X1, X2, _, R).
:- func big_rem(integer, integer) = integer.
big_rem(X1,X2) =
	R :-

	% Compare two integers. The visible conditions decide on
	% the signs first (S1 < S2, S1 > S2, both zero) and then
	% branch on S1 = 1.
	% NOTE(review): the result of every branch, the final
	% else branch and the closing parenthesis are missing in
	% this copy. Presumably equal signs fall back to comparing
	% the magnitudes D1 and D2 with pos_cmp, reversing the
	% result when both are negative.
:- func big_cmp(integer, integer) = comparison.
big_cmp(i(S1,D1), i(S2,D2)) =
	( S1 < S2 ->
	; S1 > S2 ->
	; (S1=0, S2=0) ->
	; S1=1 ->

% Compare the magnitudes of two little-endian digit lists.
% Both are normalised first, so high-order zero digits do
% not affect the result.
:- func pos_cmp(list(digit), list(digit)) = comparison.
pos_cmp(Xs, Ys) = pos_cmp_2(norm(Xs), norm(Ys)).

% Compare two normalised little-endian digit lists by
% magnitude. Because the inputs carry no high-order zeros
% (pos_cmp normalises them), a longer list is the larger
% number: the recursion reaches [_|_] versus [] and that
% verdict propagates back. For lists of equal length the
% more significant tail is compared first, and only when the
% tails are equal does the current, less significant digit
% decide.
:- func pos_cmp_2(list(digit), list(digit)) = comparison.
pos_cmp_2([],[]) = equal.
pos_cmp_2([_X|_Xs],[]) = greaterthan.
pos_cmp_2([],[_Y|_Ys]) = lessthan.
pos_cmp_2([X|Xs],[Y|Ys]) =
	Cmp :-
	Res = pos_cmp_2(Xs,Ys),
	( (Res = lessthan ; Res = greaterthan) ->
		Cmp = Res
	; X = Y ->
		Cmp = equal
	; X < Y ->
		Cmp = lessthan
	;
		Cmp = greaterthan
	).

% Add two integers.
% Equal signs: add the magnitudes and keep the common sign.
% Different signs: subtract the smaller magnitude from the
% larger. The result takes the sign of the operand with the
% larger magnitude, or is zero when the magnitudes are equal.
:- func big_plus(integer, integer) = integer.
big_plus(i(S1,Ds1), i(S2,Ds2)) =
	Sum :-
	( S1 = S2 ->
		Sum = i(S1, pos_plus(Ds1,Ds2))
	;
		% The signs differ, and one of them may be 0.
		% The sign is taken from S1 or S2 directly, not
		% assumed from the other operand's sign. This keeps
		% 0 + (-N) negative: here S1 = 0, pos_cmp([], Ds2) is
		% lessthan, and the result gets S2.
		C = pos_cmp(Ds1, Ds2),
		( C = lessthan ->
			Sum = i(S2, pos_sub(Ds2, Ds1))
		; C = greaterthan ->
			Sum = i(S1, pos_sub(Ds1, Ds2))
		;
			Sum = zero
		)
	).

% Parse a decimal string, with an optional leading '-'.
% NOTE(review): Cs is never bound in this copy; a line that
% converts S to its character list (presumably
% string__to_char_list(S, Cs)) appears to be missing.
integer__from_string(S) =
	Big :-
	string_to_integer(Cs) = Big.

% Convert a list of characters to an integer. A leading '-'
% negates the value of the rest of the list; otherwise every
% character must be a decimal digit. Calls error/1 if the
% characters cannot be parsed, including the empty list
% (reached for "" and for a lone "-").
:- func string_to_integer(list(char)) = integer.
string_to_integer(CCs) =
	Result :-
	( CCs = [] ->
		error("string_to_integer: can't parse empty string"),
		Result = zero
	; CCs = [C|Cs] ->
		% Note: '-' must be in parentheses.
		( C = ('-') ->
			Result = big_neg(string_to_integer(Cs))
		; char__is_digit(C) ->
			Digs = string_to_integer_acc(CCs,[]),
			% An all-zero digit string normalises to [].
			% Return the canonical zero rather than
			% i(1, []), so that the result unifies with zero
			% (big_div_rem and big_pow test "= zero" and
			% "= integer(0)" by unification).
			( Digs = [] ->
				Result = zero
			;
				Result = i(1,Digs)
			)
		;
			Result = zero,
			error("string_to_integer: can't parse string")
		)
	;
		Result = zero,
		error("string_to_integer: impossible value")
	).

% Accumulate decimal digit characters, most significant
% first, into a little-endian base 10000 digit list: for
% each character C the new accumulator is 10 * Acc plus the
% value of C. Calls error/1 on a non-digit character.
% NOTE(review): D1 and Z are never bound in this copy
% (presumably the character codes of C and '0'), and the
% "; " before the error call and the closing ")." are
% missing.
% NOTE(review): the error message names
% integer:integer(string), which is not a function of this
% module; integer:from_string appears to be meant.
:- func string_to_integer_acc(list(char), list(digit)) = list(digit).
string_to_integer_acc([],Acc) = Acc.
string_to_integer_acc([C|Cs],Acc) =
	Result :-
	( char__is_digit(C) ->
		Dig = pos_int_to_digits(D1 - Z),
		NewAcc = pos_plus(Dig,mul_by_digit(10,Acc)),
		Result = string_to_integer_acc(Cs,NewAcc)
		error("integer:integer(string): can't parse string")

% NOTE(review): the body of integer/1 is missing in this
% copy; presumably it is int_to_integer(N).
integer(N) =

	% Convert an int to an integer: the sign comes from signum
	% and the digits from pos_int_to_digits.
	% NOTE(review): AD (presumably abs(D)) is never bound in
	% this copy; the body line that defines it is missing. If
	% it is abs(D), D = min_int would overflow.
:- func int_to_integer(int) = integer.
int_to_integer(D) =
	i(signum(D),pos_int_to_digits(AD)) :-

% The sign of an int: -1 if negative, 0 if zero, 1 if
% positive.
:- func signum(int) = int.
signum(N) =
	SN :-
	( N < 0 ->
		SN = -1
	; N = 0 ->
		SN = 0
	;
		SN = 1
	).

% Split a non-negative int into little-endian base 10000
% digits; 0 gives [].
% NOTE(review): S1 and C1 are never bound in this copy
% (presumably by chop(D, S1, C1)), and the "; " before the
% recursive case and the closing ")." are missing.
:- func pos_int_to_digits(int) = list(digit).
pos_int_to_digits(D) =
	Result :-
	( D = 0 ->
		Result = []
		Result = [ S1 | pos_int_to_digits(C1) ],

	% Multiply a list of digits by the base.
	% With little-endian digits this amounts to putting a zero
	% digit in front of a non-empty list.
	% NOTE(review): the body is missing in this copy.
:- func mul_base(list(digit)) = list(digit).
mul_base(Xs) =

% Multiply a little-endian digit list by a single digit, then
% normalise so that every element is again a base 10000
% digit.
:- func mul_by_digit(digit, list(digit)) = list(digit).
mul_by_digit(D, Xs) = norm(mul_by_digit_2(D, Xs)).

% Multiply every element of the list by D, with no carrying;
% the result may hold values >= base.
:- func mul_by_digit_2(digit, list(digit)) = list(digit).
mul_by_digit_2(D, Xs0) = Ys :-
	(
		Xs0 = [],
		Ys = []
	;
		Xs0 = [X | Xs],
		Ys = [D * X | mul_by_digit_2(D, Xs)]
	).

	% Normalise a list of ints so that each element of the list
	% is a base 10000 digit and there are no extraneous zeros
	% at the big end. (Note: the big end (most significant
	% digit) is at the end of the list.)
	% NOTE(review): the body is missing in this copy;
	% presumably it is nuke_zeros(norm_2(Xs, 0)).
:- func norm(list(int)) = list(digit).
norm(Xs) =

% Remove zero digits from the big (last) end of a list by
% dropping leading zeros from its reverse.
% NOTE(review): RXs and Zs are never bound in this copy;
% presumably RXs is the reverse of Xs and Zs the reverse of
% RZs.
:- func nuke_zeros(list(digit)) = list(digit).
nuke_zeros(Xs) =
	Zs :-
	RZs = drop_while(equals_zero,RXs),

% Propagate carries through a little-endian list of ints,
% starting with carry C, so that every element becomes a
% base 10000 digit. A final non-zero carry is appended at
% the big end as one element.
% NOTE(review): in this copy the first clause lacks the "; "
% before "Xs = [C]" and its closing ")."; the second clause
% never binds S1 and C1 (presumably via chop(XC, S1, C1))
% and is not terminated.
:- func norm_2(list(int),digit) = list(digit).
norm_2([],C) =
	Xs :-
	( C = 0 ->
		Xs = []
		Xs = [C]
norm_2([X|Xs],C) =
	[S1 | norm_2(Xs,C1)] :-
	XC = X + C,

	% Split N into its lowest base 10000 digit and the carry to
	% the next position, so that N = Carry * base + Low.
	% Per the int module, div rounds towards minus infinity and
	% mod agrees with it. So Low stays in 0 .. base - 1 even for
	% negative N, in which case Carry is negative.
:- pred chop(int, digit, digit).
:- mode chop(in, out, out) is det.
chop(N, Low, Carry) :-
	Carry = N div base,
	Low = N mod base.

% Succeeds iff its argument is 0.
:- pred equals_zero(int).
:- mode equals_zero(in) is semidet.
equals_zero(0).

% Presumably drops leading elements of the list for as long
% as F succeeds on them (inferred from the name; the body is
% incomplete).
% NOTE(review): only the condition of the if-then-else
% survives in this copy; both branches and the closing ")."
% are missing.
:- func drop_while(pred(T),list(T)) = list(T).
:- mode drop_while(pred(in) is semidet,in) = out is det.
drop_while(_F,[]) = [].
drop_while(F,[X|Xs]) =
	( F(X) ->

% Add two non-negative magnitudes given as little-endian digit
% lists: add position by position, then normalise to propagate
% carries.
:- func pos_plus(list(digit), list(digit)) = list(digit).
pos_plus(Xs, Ys) = norm(add_pairs(Xs, Ys)).

% Subtract two magnitudes given as little-endian digit lists:
% subtract position by position, then normalise to propagate
% borrows. Callers in this file pass the larger magnitude
% first; presumably that is required.
:- func pos_sub(list(digit), list(digit)) = list(digit).
pos_sub(Xs, Ys) = norm(diff_pairs(Xs, Ys)).

% Add two little-endian lists of ints position by position,
% without carrying; positions missing from the shorter list
% count as zero. Elements of the result may be >= base, so
% the result is meant to be passed through norm.
% The three clauses cover every pair of list shapes, so no
% fallback error branch is needed.
:- func add_pairs(list(int), list(int)) = list(int).
add_pairs([], YYs) = YYs.
add_pairs([X|Xs], []) = [X|Xs].
add_pairs([X|Xs], [Y|Ys]) = [X+Y | add_pairs(Xs,Ys)].

% Subtract two little-endian lists of ints position by
% position, without borrowing; a position missing from YYs
% counts as zero. Elements of the result may be negative,
% and the result is meant to be passed through norm.
% NOTE(review): in this copy the "XXs = [] ->" branch has no
% body (presumably negating every element of YYs, e.g. with
% int_negate), and the "; " before the error call and the
% closing ")." are missing.
:- func diff_pairs(list(int), list(int)) = list(int).
diff_pairs(XXs, YYs) =
	XYs :-
	( XXs = [] ->
	; YYs = [] ->
		XYs = XXs
	; (XXs = [X|Xs], YYs = [Y|Ys]) ->
		XYs = [ X-Y | diff_pairs(Xs,Ys) ]
		error("diff_pairs: ")

% NegM is the arithmetic negation of M.
:- pred int_negate(int, int).
:- mode int_negate(in,out) is det.
int_negate(M, NegM) :-
	NegM = 0 - M.

% Schoolbook multiplication of little-endian digit lists,
% using (X + base * Xs) * Ys = X * Ys + base * (Xs * Ys).
:- func pos_mul(list(digit), list(digit)) = list(digit).
pos_mul(Xs0, Ys) = Product :-
	(
		Xs0 = [],
		Product = []
	;
		Xs0 = [X | Xs],
		Low = mul_by_digit(X, Ys),
		High = mul_base(pos_mul(Xs, Ys)),
		Product = pos_plus(Low, High)
	).

% Render an integer as a decimal string.
integer:to_string(N) = integer_to_string_2(N).

% Render an integer: "-" for a negative sign, nothing
% otherwise, followed by the magnitude's digits.
% NOTE(review): in this copy the "; " between the two Sgn
% branches, the closing ")" and the line building Str
% (presumably Sgn followed by digits_to_string(Ds)) are
% missing.
:- func integer_to_string_2(integer) = string.
integer_to_string_2(i(S,Ds)) =
	Str :-
	( S = (-1) ->
		Sgn = "-"
		Sgn = ""

% Render a little-endian digit list in decimal; an empty
% list gives "0".
% NOTE(review): Rev is never bound in this copy (presumably
% the reverse of DDs, i.e. big end first). The body of the
% non-empty case, the "; " before the error call and the
% closing ")." are also missing.
:- func digits_to_string(list(digit)) = string.
digits_to_string(DDs) =
	Str :-
	( Rev = [] ->
		Str = "0"
	; Rev = [R|Rs] ->
		error("digits_to_string: ")

% Render one base 10000 digit as a decimal string. The
% binding Width = 4 suggests lower-order digits are
% zero-padded to four characters.
% NOTE(review): everything after the Width binding is
% missing in this copy.
:- pred digit_to_string(digit,string).
:- mode digit_to_string(in,out) is det.
digit_to_string(D,S) :-
	Width = 4,	% = log10(base)

% Divide N1 by N2, giving quotient Qt and remainder Rm.
% The quotient's sign is S1 * S2 and the remainder takes the
% sign of the dividend (SR = S1). Calls error/1 if N2 is
% zero.
% NOTE(review): when N1 = zero this binds Rm = N2, but the
% remainder of 0 divided by N2 is 0 (N1 = Qt * N2 + Rm with
% Qt = 0 forces Rm = 0). Rm = zero looks intended.
% NOTE(review): the zero tests unify structurally with zero,
% so a non-canonical zero such as i(1, []) slips past them.
% This block can itself produce such values: if norm(QRR)
% or norm(RRR) is [], Qt or Rm gets a non-zero sign with no
% digits.
% NOTE(review): in this copy the "; " before the general
% case, the computation of QRR and RRR (presumably via
% div_rev on reversed digit lists) and the closing ")." are
% missing.
:- pred big_div_rem(integer, integer, integer, integer).
:- mode big_div_rem(in, in, out, out) is det.
big_div_rem(N1, N2, Qt, Rm) :-
	( N2 = zero ->
		error("big_div_rem: division by zero")
	; N1 = zero ->
		Qt = zero,
		Rm = N2
		N1 = i(S1,D1),
		N2 = i(S2,D2),
		Qt = i(SQ,Q),
		Rm = i(SR,R),
		SR = S1,
		SQ = S1 * S2,
		Q = norm(QRR),
		R = norm(RRR),

	% Algorithm: We take digits from the start of U (call them Ur)
	% and divide by V to get a digit Q of the ratio.
	% Essentially the usual long division algorithm.
	% Qhat is an approximation to Q. It may be at most 2 too big.
	% If the first digit of V is less than base/2, then
	% we scale both the numerator and denominator. This
	% way, we can use Knuth's[*] nifty trick for finding
	% an accurate approximation to Q. That's all we use from
	% Knuth; his MIX algorithm is hard to follow.
	% [*] Knuth, The Art of Computer Programming, Vol. 2:
	% Seminumerical Algorithms, section 4.3.1.
	% All digit lists here are big end first (reversed). When
	% scaling by M = base div (V0+1), the remainder is scaled
	% back down with div_by_digit_rev(M, R).
	% NOTE(review): the second line of the pred declaration,
	% the scaling and the call to div_rev_2 (which bind Q and
	% R), the unscaled branch and the closing parentheses are
	% missing in this copy.
:- pred div_rev(list(digit), list(digit), list(digit), list(digit),
:- mode div_rev(in, in, in, out, out) is det.
div_rev(Ur,U,V,Qt,Rm) :-
	( V = [V0|_] ->
		( V0 < base div 2 ->
			Qt = Q,
			Rm = div_by_digit_rev(M,R),
			M = base div (V0+1)
		error("div_rev: ")

% One step of long division on big-end-first digit lists.
% Ur is the current partial remainder, U holds the dividend
% digits not yet brought down, and V is the divisor. Each
% step emits one quotient digit; presumably the recursion
% then continues with the next digit of U brought down.
% Qhat is estimated from the leading one or two digits of Ur
% and the leading digit V0 of V. Q is then the largest of
% Qhat, Qhat-1 and Qhat-2 for which Q*V =< Ur.
% NOTE(review): Knuth's "at most 2 too big" bound holds for
% the estimate capped at base - 1. Qhat here is not capped
% (e.g. U0 = V0 gives Qhat >= base), so confirm the bound
% still holds.
% NOTE(review): in this copy the second line of the pred
% declaration, the recursive calls binding Quot and Rem,
% every "; " before the error calls and else branches, and
% the closing parentheses are missing.
:- pred div_rev_2(list(digit), list(digit), list(digit), list(digit),
:- mode div_rev_2(in, in, in, out, out) is det.
div_rev_2(Ur,U,V,Qt,Rm) :-
	( pos_lt_rev(Ur,V) ->
		( U = [] ->
			Qt = [0],
			Rm = Ur
		; U = [Ua|Uas] ->
			Qt = [0|Quot],
			Rm = Rem,
			error("div_rev_2: software error")
		( U = [] ->
			Qt = [Q],
			Rm = NewUr
		; U = [Ua|Uas] ->
			Qt = [Q|Quot],
			Rm = Rem,
			error("div_rev_2: software error")
		NewUr = pos_sub_rev(Ur,mul_by_digit_rev(Q,V)),
		( pos_geq_rev(Ur,mul_by_digit_rev(Qhat,V)) ->
			Q = Qhat
		; pos_geq_rev(Ur,mul_by_digit_rev(Qhat-1,V)) ->
			Q = Qhat-1
			Q = Qhat - 2
		V0 = head(V),
		U0 = head(Ur),
		( length(Ur) > length(V) ->
			Qhat = (U0*B+U1) div V0,
			U1 = head(tail(Ur))
			Qhat = U0 div V0
		B = base

% The number of elements in a list. Delegates to the standard
% library's list__length instead of a hand-written,
% non-tail-recursive count.
:- func length(list(T)) = int.
length(Xs) = N :-
	list__length(Xs, N).

% The first element of a non-empty list; calls error/1 on [].
:- func head(list(T)) = T.
head(HT) =
	H :-
	( HT = [Hd|_T] ->
		H = Hd
	;
		error("head: []")
	).

	% The list without its first element; calls error/1 on [].
:- func tail(list(T)) = list(T).
tail(HT) =
	T :-
	( HT = [_H|Tl] ->
		T = Tl
	;
		error("tail: []")
	).

	% Multiply a *reverse* list of digits (big end first)
	% by a digit.
	% Note: All functions whose name has the suffix "_rev"
	% operate on such reverse lists of digits.
	% NOTE(review): RXs (presumably the reverse of Xs) and Rev
	% (presumably the reverse of Mul) are never bound in this
	% copy; those lines are missing.
:- func mul_by_digit_rev(digit,list(digit)) = list(digit).
mul_by_digit_rev(D,Xs) =
	Rev :-
	Mul = mul_by_digit(D,RXs),

% Divide a big-end-first digit list by a single digit,
% producing one quotient digit per input digit.
:- func div_by_digit_rev(digit,list(digit)) = list(digit).
div_by_digit_rev(D, Ds) = Qs :-
	(
		Ds = [],
		Qs = []
	;
		Ds = [First | Rest],
		Qs = div_by_digit_rev_2(First, Rest, D)
	).

% Short division of a big-end-first digit list by D.
% X is the current partial dividend: the remainder carried
% from the previous position times base, plus the current
% digit. One quotient digit is produced per input digit,
% leading zeros included.
% The two switch arms cover both list shapes, so no fallback
% error branch is needed.
:- func div_by_digit_rev_2(digit,list(digit),digit) = list(digit).
div_by_digit_rev_2(X,Xs,D) =
	[Q|Rest] :-
	Q = X div D,
	(
		Xs = [],
		Rest = []
	;
		Xs = [H|T],
		R = X rem D,
		Rest = div_by_digit_rev_2(R*base + H,T,D)
	).

% Subtract two big-end-first digit lists (Xs - Ys): convert
% to little-endian, subtract with pos_sub, and reverse the
% result back.
% NOTE(review): RXs and RYs (presumably the reverses of Xs
% and Ys) are never bound in this copy.
:- func pos_sub_rev(list(digit), list(digit)) = list(digit).
pos_sub_rev(Xs,Ys) =
	Rev :-
	Sum = pos_sub(RXs,RYs),
	list__reverse(Sum, Rev).

% Magnitude comparisons on big-end-first digit lists. Each
% wraps its arguments as positive integers and uses big_cmp.
% NOTE(review): RXs and RYs (presumably the reverses of Xs
% and Ys) are never bound in this copy, in either predicate.
:- pred pos_lt_rev(list(digit), list(digit)).
:- mode pos_lt_rev(in, in) is semidet.
pos_lt_rev(Xs,Ys) :-
	big_cmp(i(1,RXs),i(1,RYs)) = lessthan.

:- pred pos_geq_rev(list(digit), list(digit)).
:- mode pos_geq_rev(in, in) is semidet.
pos_geq_rev(Xs,Ys) :-
	C = big_cmp(i(1,RXs),i(1,RYs)),
	( C = greaterthan ; C = equal).

% pow(A, N, P): P is A raised to the power N.
% Calls error/1 if N is negative.
integer:pow(A,N,P) :-
	( N < integer(0) ->
		error("big_pow: negative exponent")
	;
		P = big_pow(A,N)
	).

% A raised to the non-negative power N, by repeated squaring.
% An odd exponent peels off one factor of A; an even one
% squares A^(N/2). So there are O(log N) recursive calls.
% NOTE(review): "N = integer(0)" is structural unification,
% so it relies on N - integer(1) and N / integer(2) producing
% the canonical zero.
:- func big_pow(integer, integer) = integer.
big_pow(A,N) =
	P :-
	( N = integer(0) ->
		P = integer(1)
	; big_odd(N) ->
		P = A * big_pow(A,N-integer(1))
	; % even
		P = big_sqr(big_pow(A, N/integer(2)))
	).

% The square of an integer.
:- func big_sqr(integer) = integer.
big_sqr(A) = Square :-
	Square = A * A.

% Succeeds iff N is odd. The lowest base 10000 digit comes
% first in the digit list, and base is even, so N is odd
% exactly when that digit is odd. Zero has no digits, so the
% head pattern fails for it without a separate test.
:- pred big_odd(integer).
:- mode big_odd(in) is semidet.
big_odd(i(_S, [D|_Ds])) :-
	D mod 2 = 1.

% Possible improvements:
%	1) Allow negative digits (-base+1 .. base-1) in lists of
%	  digits and normalise only when printing. This would
%	  probably simplify the division algorithm, also.
%	2) Alternatively, instead of using base=10000, use *all* the
%	  bits in an int and make use of the properties of machine
%	  arithmetic. Base 10000 doesn't use even half the bits
%	  in an int, which is inefficient.
%	3) Use Karatsuba multiplication, a divide-and-conquer
%	  algorithm that runs in O(n^log2(3)) ~ O(n^1.585) time,
%	  for multiplying large integers, rather than the current
%	  O(n^2) schoolbook method.
%	4) Fourier methods and multiplication via modular arithmetic
%	  are left as exercises to the reader. 8^)
%	5) Use a double-ended list type rather than simple lists.
%	  This would avoid the need for the list reversals that
%	  are performed in the division algorithm.
%	Of the above, 1) and 5) would give the most benefit for
%	the effort, and 3) is straightforward to implement.

