for review: deforestation [4/4]

Simon Taylor stayl at cs.mu.OZ.AU
Thu Feb 19 16:38:07 AEDT 1998


%-----------------------------------------------------------------------------%
% Copyright (C) 1998 University of Melbourne.
% This file may only be copied under the terms of the GNU General
% Public License - see the file COPYING in the Mercury distribution.
%-----------------------------------------------------------------------------%
% File: pd_info.m
% Main author: stayl
%
% Types for deforestation.
%-----------------------------------------------------------------------------%

:- module pd_info.

:- interface.

:- import_module pd_term, hlds_module, options, hlds_data, instmap.
:- import_module map, list, io, set, std_util, term, getopt.

% The state threaded through the deforestation pass.  The fields are
% positional; the pd_info_get_* / pd_info_set_* predicates declared
% below name each one.
:- type pd_info 
	---> pd_info(
		io__state,		% see pd_info_get_io_state
		module_info,		% module being optimised
		maybe(unfold_info),	% state for the procedure being
					% unfolded in; `no' initially and
					% after pd_info_unset_unfold_info
		goal_version_index,	% called procs -> versions
		version_index,		% version id -> version_info
		pd_arg_info,		% see pd_info_get_proc_arg_info
		int,			% version counter (used to name
					% the predicates created by
					% pd_info__define_new_pred).
		global_term_info,	% termination info (see pd_term)
		set(pred_proc_id),	% parent versions
		int,			% current depth
		set(pred_proc_id),	% created versions
		set(pair(pred_proc_id)),% pairs of procedures which when
					% paired for deforestation produce
					% little improvement
		unit,			% unused in this module
		unit			% unused in this module
	).

		% Map from the list of procedures called in a conjunction
		% (as computed by pd_util__goal_get_calls) to the
		% specialised versions created for such conjunctions.
		% pd_info__register_version adds new versions at the
		% front of each list.
:- type goal_version_index == map(list(pred_proc_id), list(pred_proc_id)).

		% Map from a version's pred_proc_id to the information
		% recorded about that version.
:- type version_index == map(pred_proc_id, version_info).

% Insts and modes for threading a pd_info.  unique_pd_info is defined
% as plain `ground', so these modes do not make the compiler enforce
% single use of a pd_info.  Callers still thread it linearly because it
% holds the io__state.  NOTE(review): presumably a stand-in for real
% unique modes; confirm before relying on it.
:- inst unique_pd_info = ground.

:- mode pd_info_di :: unique_pd_info -> dead.
:- mode pd_info_uo :: free -> unique_pd_info.
:- mode pd_info_ui :: unique_pd_info -> unique_pd_info.

	% Inst of a pd_info whose io__state has been taken out by
	% pd_info_get_io_state.  pd_info_set_io_state consumes a
	% pd_info in this inst and puts the io__state back.
:- inst pd_info_no_io = ground.
:- mode pd_info_set_io :: pd_info_no_io -> dead.

% pd_info_init(ModuleInfo, ProcArgInfo, IO, PdInfo):
%	Create a pd_info holding the given module_info, argument info
%	and io__state.  It has no unfold_info, empty version tables,
%	and a version counter and depth of zero.
:- pred pd_info_init(module_info, pd_arg_info, io__state, pd_info).
:- mode pd_info_init(in, in, di, pd_info_uo) is det.

	% pd_info_init_unfold_info(PredProcId, PredInfo, ProcInfo):
	%	Start unfolding within the given procedure.  The new
	%	unfold_info starts with the procedure's initial instmap,
	%	zero cost and size deltas, fresh local termination info,
	%	a parent set containing just PredProcId, and the
	%	"determinism analysis needed" flag set to `no'.
:- pred pd_info_init_unfold_info(pred_proc_id, 
		pred_info, proc_info, pd_info, pd_info).
:- mode pd_info_init_unfold_info(in, in, in, pd_info_di, pd_info_uo) is det.

	% Take the io__state out of the pd_info.  The resulting pd_info
	% has inst pd_info_no_io.  The io__state is put back with
	% pd_info_set_io_state (see pd_info_lookup_option).
:- pred pd_info_get_io_state(io__state, pd_info, pd_info).
:- mode pd_info_get_io_state(uo, pd_info_di, out(pd_info_no_io)) is det.

% Read accessors for the fields of pd_info.  None of them changes the
% pd_info.  pd_info_get_unfold_info calls error/1 if no unfold_info is
% currently set.
:- pred pd_info_get_module_info(module_info, pd_info, pd_info).
:- mode pd_info_get_module_info(out, pd_info_di, pd_info_uo) is det.

:- pred pd_info_get_unfold_info(unfold_info, pd_info, pd_info).
:- mode pd_info_get_unfold_info(out, pd_info_di, pd_info_uo) is det.

	% Index from the procedures called by a goal to the versions
	% created for goals making those calls.
:- pred pd_info_get_goal_version_index(goal_version_index, pd_info, pd_info).
:- mode pd_info_get_goal_version_index(out, pd_info_di, pd_info_uo) is det.

:- pred pd_info_get_versions(version_index, pd_info, pd_info).
:- mode pd_info_get_versions(out, pd_info_di, pd_info_uo) is det.

:- pred pd_info_get_proc_arg_info(pd_arg_info, pd_info, pd_info).
:- mode pd_info_get_proc_arg_info(out, pd_info_di, pd_info_uo) is det.

	% Counter used to number newly created predicates.
:- pred pd_info_get_counter(int, pd_info, pd_info).
:- mode pd_info_get_counter(out, pd_info_di, pd_info_uo) is det.

:- pred pd_info_get_global_term_info(global_term_info, pd_info, pd_info).
:- mode pd_info_get_global_term_info(out, pd_info_di, pd_info_uo) is det.

:- pred pd_info_get_parent_versions(set(pred_proc_id), pd_info, pd_info).
:- mode pd_info_get_parent_versions(out, pd_info_di, pd_info_uo) is det.

:- pred pd_info_get_depth(int, pd_info, pd_info).
:- mode pd_info_get_depth(out, pd_info_di, pd_info_uo) is det.

	% Versions added by pd_info__register_version and not yet
	% removed by pd_info__remove_version.
:- pred pd_info_get_created_versions(set(pred_proc_id), pd_info, pd_info).
:- mode pd_info_get_created_versions(out, pd_info_di, pd_info_uo) is det.

	% Pairs of procedures whose deforestation produced little
	% improvement (see pd_info__invalidate_version).
:- pred pd_info_get_useless_versions(set(pair(pred_proc_id)), pd_info, pd_info).
:- mode pd_info_get_useless_versions(out, pd_info_di, pd_info_uo) is det.

% Update accessors for the fields of pd_info.  Each replaces a single
% field and copies the rest.  pd_info_set_io_state takes a pd_info
% whose io__state was removed by pd_info_get_io_state.
:- pred pd_info_set_io_state(io__state, pd_info, pd_info).
:- mode pd_info_set_io_state(di, pd_info_set_io, pd_info_uo) is det.

:- pred pd_info_set_module_info(module_info, pd_info, pd_info).
:- mode pd_info_set_module_info(in, pd_info_di, pd_info_uo) is det.

	% Stores the unfold_info as yes(UnfoldInfo).
:- pred pd_info_set_unfold_info(unfold_info, pd_info, pd_info).
:- mode pd_info_set_unfold_info(in, pd_info_di, pd_info_uo) is det.

:- pred pd_info_set_goal_version_index(goal_version_index, pd_info, pd_info).
:- mode pd_info_set_goal_version_index(in, pd_info_di, pd_info_uo) is det.

:- pred pd_info_set_versions(version_index, pd_info, pd_info).
:- mode pd_info_set_versions(in, pd_info_di, pd_info_uo) is det.

:- pred pd_info_set_proc_arg_info(pd_arg_info, pd_info, pd_info).
:- mode pd_info_set_proc_arg_info(in, pd_info_di, pd_info_uo) is det.

:- pred pd_info_set_counter(int, pd_info, pd_info).
:- mode pd_info_set_counter(in, pd_info_di, pd_info_uo) is det.

:- pred pd_info_set_global_term_info(global_term_info, pd_info, pd_info).
:- mode pd_info_set_global_term_info(in, pd_info_di, pd_info_uo) is det.

:- pred pd_info_set_parent_versions(set(pred_proc_id), pd_info, pd_info).
:- mode pd_info_set_parent_versions(in, pd_info_di, pd_info_uo) is det.

:- pred pd_info_set_depth(int, pd_info, pd_info).
:- mode pd_info_set_depth(in, pd_info_di, pd_info_uo) is det.

:- pred pd_info_set_created_versions(set(pred_proc_id), pd_info, pd_info).
:- mode pd_info_set_created_versions(in, pd_info_di, pd_info_uo) is det.

:- pred pd_info_set_useless_versions(set(pair(pred_proc_id)), pd_info, pd_info).
:- mode pd_info_set_useless_versions(in, pd_info_di, pd_info_uo) is det.

% Move the current instmap (in the unfold_info) past the given goal by
% applying the goal's instmap delta.
:- pred pd_info_update_goal(hlds_goal, pd_info, pd_info).
:- mode pd_info_update_goal(in, pd_info_di, pd_info_uo) is det.

	% Look up an option with globals__io_lookup_option, using the
	% io__state stored in the pd_info.
:- pred pd_info_lookup_option(option, option_data, pd_info, pd_info).
:- mode pd_info_lookup_option(in, out, pd_info_di, pd_info_uo) is det.

	% As above, for options of the form bool(_).  Calls error/1 if
	% the option's value is not a bool.
:- pred pd_info_lookup_bool_option(option, bool, pd_info, pd_info).
:- mode pd_info_lookup_bool_option(in, out, pd_info_di, pd_info_uo) is det.

	% Record in the current instmap (and, if needed, the
	% module_info) that Var is bound to the functor ConsId.
:- pred pd_info_bind_var_to_functor(var, cons_id, pd_info, pd_info).
:- mode pd_info_bind_var_to_functor(in, in, pd_info_di, pd_info_uo) is det.

	% Reset the unfold_info to `no'.
:- pred pd_info_unset_unfold_info(pd_info, pd_info).
:- mode pd_info_unset_unfold_info(pd_info_di, pd_info_uo) is det.

	% Versions of list__foldl and list__foldl2 that thread a
	% pd_info using the pd_info_di/pd_info_uo modes.
	% With polymorphic modes these would be unnecessary.
:- pred pd_info_foldl(pred(T, pd_info, pd_info), list(T), pd_info, pd_info).
:- mode pd_info_foldl(pred(in, pd_info_di, pd_info_uo) is det, in,
		pd_info_di, pd_info_uo) is det.

:- pred pd_info_foldl2(pred(T, U, U, pd_info, pd_info), list(T), U, U, 
		pd_info, pd_info).
:- mode pd_info_foldl2(pred(in, in, out, pd_info_di, pd_info_uo) is det, in,
		in, out, pd_info_di, pd_info_uo) is det.

%-----------------------------------------------------------------------------%
:- implementation.

:- import_module hlds_pred, prog_data, pd_debug, pd_util, det_util, globals.
:- import_module inst_match, hlds_goal.
:- import_module assoc_list, bool, int, require, string.

% Build the initial pd_info: no unfold_info yet, empty version tables
% and version sets, and both the version counter and depth at zero.
pd_info_init(ModuleInfo, ProcArgInfos, IO, PdInfo) :-
	pd_term__global_term_info_init(GlobalTermInfo),
	map__init(EmptyGoalIndex),
	map__init(EmptyVersions),
	set__init(NoParents),
	set__init(NoCreated),
	set__init(NoUseless),
	VersionCounter = 0,
	Depth = 0,
	PdInfo = pd_info(IO, ModuleInfo, no, EmptyGoalIndex, EmptyVersions,
		ProcArgInfos, VersionCounter, GlobalTermInfo, NoParents,
		Depth, NoCreated, NoUseless, unit, unit).

% Install a fresh unfold_info for unfolding within PredProcId.  The
% instmap starts at the procedure's initial instmap, both deltas are
% zero, and the procedure itself is the only parent.
pd_info_init_unfold_info(PredProcId, PredInfo, ProcInfo) -->
	pd_info_get_module_info(ModuleInfo),
	{ proc_info_get_initial_instmap(ProcInfo, ModuleInfo,
		InitialInstMap) },
	{ pd_term__local_term_info_init(LocalTermInfo) },
	{ set__singleton_set(Parents, PredProcId) },
	{ CostDelta = 0 },
	{ SizeDelta = 0 },
	{ Changed = no },
	pd_info_set_unfold_info(unfold_info(ProcInfo, InitialInstMap,
		CostDelta, LocalTermInfo, PredInfo, Parents, PredProcId,
		Changed, SizeDelta, unit)).

% Field getters for pd_info.  Each one returns its input pd_info
% unchanged; only the requested field is extracted.
pd_info_get_io_state(IO, PdInfo, PdInfo) :-
	PdInfo = pd_info(IO0, _,_,_,_,_,_,_,_,_,_,_,_,_),
	% The stored io__state is only ground as far as the insts are
	% concerned, so uniqueness is promised for the `uo' output.
	unsafe_promise_unique(IO0, IO).
pd_info_get_module_info(ModuleInfo, PdInfo, PdInfo) :-
	PdInfo = pd_info(_, ModuleInfo, _,_,_,_,_,_,_,_,_,_,_,_).
pd_info_get_unfold_info(UnfoldInfo, PdInfo, PdInfo) :-
	PdInfo = pd_info(_,_, MaybeUnfoldInfo, _,_,_,_,_,_,_,_,_,_,_),
	(
		MaybeUnfoldInfo = yes(UnfoldInfo)
	;
		MaybeUnfoldInfo = no,
		error("pd_info_get_unfold_info: unfold_info not set.")
	).
pd_info_get_goal_version_index(Index, PdInfo, PdInfo) :-
	PdInfo = pd_info(_,_,_,Index,_,_,_,_,_,_,_,_,_,_).
pd_info_get_versions(Versions, PdInfo, PdInfo) :-
	PdInfo = pd_info(_,_,_,_,Versions,_,_,_,_,_,_,_,_,_).
pd_info_get_proc_arg_info(ProcArgInfo, PdInfo, PdInfo) :-
	PdInfo = pd_info(_,_,_,_,_,ProcArgInfo,_,_,_,_,_,_,_,_).
pd_info_get_counter(Counter, PdInfo, PdInfo) :-
	PdInfo = pd_info(_,_,_,_,_,_,Counter,_,_,_,_,_,_,_).
pd_info_get_global_term_info(TermInfo, PdInfo, PdInfo) :-
	PdInfo = pd_info(_,_,_,_,_,_,_,TermInfo,_,_,_,_,_,_).
pd_info_get_parent_versions(Parents, PdInfo, PdInfo) :-
	PdInfo = pd_info(_,_,_,_,_,_,_,_,Parents,_,_,_,_,_).
pd_info_get_depth(Depth, PdInfo, PdInfo) :-
	PdInfo = pd_info(_,_,_,_,_,_,_,_,_,Depth,_,_,_,_).
pd_info_get_created_versions(Versions, PdInfo, PdInfo) :-
	PdInfo = pd_info(_,_,_,_,_,_,_,_,_,_,Versions,_,_,_).
pd_info_get_useless_versions(Versions, PdInfo, PdInfo) :-
	PdInfo = pd_info(_,_,_,_,_,_,_,_,_,_,_,Versions,_,_).

% Field setters for pd_info.  Each one takes the old pd_info apart,
% then rebuilds it with exactly one field replaced.
pd_info_set_io_state(IO0, PdInfo0, PdInfo) :-
	PdInfo0 = pd_info(_, B, C, D, E, F, G, H, I, J, K, L, M, N),
	% Promise uniqueness for the io__state being stored.
	unsafe_promise_unique(IO0, IO),
	PdInfo = pd_info(IO, B, C, D, E, F, G, H, I, J, K, L, M, N).
pd_info_set_module_info(ModuleInfo, PdInfo0, PdInfo) :-
	PdInfo0 = pd_info(A, _, C, D, E, F, G, H, I, J, K, L, M, N),
	PdInfo = pd_info(A, ModuleInfo, C, D, E, F, G, H, I, J, K, L, M, N).
pd_info_set_unfold_info(UnfoldInfo, PdInfo0, PdInfo) :-
	PdInfo0 = pd_info(A, B, _, D, E, F, G, H, I, J, K, L, M, N),
	PdInfo = pd_info(A, B, yes(UnfoldInfo), D, E, F, G, H, I, J,
		K, L, M, N).
pd_info_unset_unfold_info(PdInfo0, PdInfo) :-
	PdInfo0 = pd_info(A, B, _, D, E, F, G, H, I, J, K, L, M, N),
	PdInfo = pd_info(A, B, no, D, E, F, G, H, I, J, K, L, M, N).
pd_info_set_goal_version_index(Index, PdInfo0, PdInfo) :-
	PdInfo0 = pd_info(A, B, C, _, E, F, G, H, I, J, K, L, M, N),
	PdInfo = pd_info(A, B, C, Index, E, F, G, H, I, J, K, L, M, N).
pd_info_set_versions(Versions, PdInfo0, PdInfo) :-
	PdInfo0 = pd_info(A, B, C, D, _, F, G, H, I, J, K, L, M, N),
	PdInfo = pd_info(A, B, C, D, Versions, F, G, H, I, J, K, L, M, N).
pd_info_set_proc_arg_info(ProcArgInfo, PdInfo0, PdInfo) :-
	PdInfo0 = pd_info(A, B, C, D, E, _, G, H, I, J, K, L, M, N),
	PdInfo = pd_info(A, B, C, D, E, ProcArgInfo, G, H, I, J, K, L, M, N).
pd_info_set_counter(Counter, PdInfo0, PdInfo) :-
	PdInfo0 = pd_info(A, B, C, D, E, F, _, H, I, J, K, L, M, N),
	PdInfo = pd_info(A, B, C, D, E, F, Counter, H, I, J, K, L, M, N).
pd_info_set_global_term_info(TermInfo, PdInfo0, PdInfo) :-
	PdInfo0 = pd_info(A, B, C, D, E, F, G, _, I, J, K, L, M, N),
	PdInfo = pd_info(A, B, C, D, E, F, G, TermInfo, I, J, K, L, M, N).
pd_info_set_parent_versions(Parents, PdInfo0, PdInfo) :-
	PdInfo0 = pd_info(A, B, C, D, E, F, G, H, _, J, K, L, M, N),
	PdInfo = pd_info(A, B, C, D, E, F, G, H, Parents, J, K, L, M, N).
pd_info_set_depth(Depth, PdInfo0, PdInfo) :-
	PdInfo0 = pd_info(A, B, C, D, E, F, G, H, I, _, K, L, M, N),
	PdInfo = pd_info(A, B, C, D, E, F, G, H, I, Depth, K, L, M, N).
pd_info_set_created_versions(Versions, PdInfo0, PdInfo) :-
	PdInfo0 = pd_info(A, B, C, D, E, F, G, H, I, J, _, L, M, N),
	PdInfo = pd_info(A, B, C, D, E, F, G, H, I, J, Versions, L, M, N).
pd_info_set_useless_versions(Versions, PdInfo0, PdInfo) :-
	PdInfo0 = pd_info(A, B, C, D, E, F, G, H, I, J, K, _, M, N),
	PdInfo = pd_info(A, B, C, D, E, F, G, H, I, J, K, Versions, M, N).

% Advance the unfold_info's instmap past Goal by applying the goal's
% instmap delta.
pd_info_update_goal(Goal) -->
	{ Goal = _GoalExpr - GoalInfo },
	{ goal_info_get_instmap_delta(GoalInfo, InstMapDelta) },
	pd_info_get_instmap(InstMapBefore),
	{ instmap__apply_instmap_delta(InstMapBefore, InstMapDelta,
		InstMapAfter) },
	pd_info_set_instmap(InstMapAfter).

% Borrow the io__state from the pd_info, look the option up, and put
% the io__state back.
pd_info_lookup_option(Option, OptionData, PdInfo0, PdInfo) :-
	pd_info_get_io_state(IO0, PdInfo0, PdInfo1),
	globals__io_lookup_option(Option, OptionData, IO0, IO),
	pd_info_set_io_state(IO, PdInfo1, PdInfo).

% Look up a boolean option; any non-bool value is an internal error.
pd_info_lookup_bool_option(Option, Value, PdInfo0, PdInfo) :-
	pd_info_lookup_option(Option, OptionData, PdInfo0, PdInfo),
	( OptionData = bool(Bool) ->
		Value = Bool
	;
		error("pd_info_lookup_bool_option")
	).

% Record that Var is bound to ConsId.  instmap__bind_var_to_functor may
% also update the module_info, so both are written back.
pd_info_bind_var_to_functor(Var, ConsId, PdInfo0, PdInfo) :-
	pd_info_get_instmap(InstMap0, PdInfo0, PdInfo1),
	pd_info_get_module_info(ModuleInfo0, PdInfo1, PdInfo2),
	instmap__bind_var_to_functor(Var, ConsId, InstMap0, InstMap,
		ModuleInfo0, ModuleInfo),
	pd_info_set_instmap(InstMap, PdInfo2, PdInfo3),
	pd_info_set_module_info(ModuleInfo, PdInfo3, PdInfo).

% Left folds over a list that thread the pd_info (and, for foldl2, an
% extra accumulator) through each call of Pred, first element first.
pd_info_foldl(_, [], PdInfo, PdInfo).
pd_info_foldl(Pred, [Elem | Elems], PdInfo0, PdInfo) :-
	call(Pred, Elem, PdInfo0, PdInfo1),
	pd_info_foldl(Pred, Elems, PdInfo1, PdInfo).

pd_info_foldl2(_, [], Acc, Acc, PdInfo, PdInfo).
pd_info_foldl2(Pred, [Elem | Elems], Acc0, Acc, PdInfo0, PdInfo) :-
	call(Pred, Elem, Acc0, Acc1, PdInfo0, PdInfo1),
	pd_info_foldl2(Pred, Elems, Acc1, Acc, PdInfo1, PdInfo).

%-----------------------------------------------------------------------------%

:- interface.

% State for the procedure in which unfolding is currently being done.
% It is stored in the pd_info and read and written via
% pd_info_get_proc_info, pd_info_set_instmap, etc.
:- type unfold_info
	--->	unfold_info(
			proc_info,	% procedure being unfolded in
			instmap,	% current instmap
			int,		% cost delta
			local_term_info,% local termination info
			pred_info,	% pred_info of that procedure
			set(pred_proc_id), % parents (initially just the
					% current pred_proc_id)
			pred_proc_id,	% current pred_proc_id
			bool,		% does determinism analysis
					% need to be run
			int,		% size delta
			unit		% unused in this module
		).	

	% pd_arg_info records, for each procedure, the arguments for
	% which it might be worthwhile to attempt deforestation if
	% there is extra information about them.  It also records which
	% branches of the single branched goal in the procedure's
	% top-level conjunction produce that extra information.
:- type pd_arg_info == map(pred_proc_id, pd_proc_arg_info).

	% Branch information for one procedure.  NOTE(review): the int
	% presumably identifies an argument (e.g. by position);
	% confirm against the code that builds pd_arg_info.
:- type pd_proc_arg_info	==	pd_branch_info(int).

	% Branch information about variables of type T.
:- type pd_branch_info(T)
	--->	pd_branch_info(
			branch_info_map(T),
			set(T),		% variables for which we want
					% extra left context
			set(T)		% outputs for which we have no
					% extra information
		).

	% Vars for which there is extra information at the end
	% of some branches, and the branches which add the extra
	% information (numbered from 1).
:- type branch_info_map(T)	==	map(T, set(int)).

% Accessors for the fields of the current unfold_info.  They all go
% through pd_info_get_unfold_info, so they call error/1 if no
% unfold_info is set.
:- pred pd_info_get_proc_info(proc_info, pd_info, pd_info).
:- mode pd_info_get_proc_info(out, pd_info_di, pd_info_uo) is det.

:- pred pd_info_get_instmap(instmap, pd_info, pd_info).
:- mode pd_info_get_instmap(out, pd_info_di, pd_info_uo) is det.

:- pred pd_info_get_cost_delta(int, pd_info, pd_info).
:- mode pd_info_get_cost_delta(out, pd_info_di, pd_info_uo) is det.

:- pred pd_info_get_local_term_info(local_term_info, pd_info, pd_info).
:- mode pd_info_get_local_term_info(out, pd_info_di, pd_info_uo) is det.

:- pred pd_info_get_pred_info(pred_info, pd_info, pd_info).
:- mode pd_info_get_pred_info(out, pd_info_di, pd_info_uo) is det.

:- pred pd_info_get_parents(set(pred_proc_id), pd_info, pd_info).
:- mode pd_info_get_parents(out, pd_info_di, pd_info_uo) is det.

:- pred pd_info_get_pred_proc_id(pred_proc_id, pd_info, pd_info).
:- mode pd_info_get_pred_proc_id(out, pd_info_di, pd_info_uo) is det.

	% The "does determinism analysis need to be run" flag.
:- pred pd_info_get_changed(bool, pd_info, pd_info).
:- mode pd_info_get_changed(out, pd_info_di, pd_info_uo) is det.

:- pred pd_info_get_size_delta(int, pd_info, pd_info).
:- mode pd_info_get_size_delta(out, pd_info_di, pd_info_uo) is det.

:- pred pd_info_set_proc_info(proc_info, pd_info, pd_info).
:- mode pd_info_set_proc_info(in, pd_info_di, pd_info_uo) is det.

:- pred pd_info_set_instmap(instmap, pd_info, pd_info).
:- mode pd_info_set_instmap(in, pd_info_di, pd_info_uo) is det.

:- pred pd_info_set_cost_delta(int, pd_info, pd_info).
:- mode pd_info_set_cost_delta(in, pd_info_di, pd_info_uo) is det.

:- pred pd_info_set_local_term_info(local_term_info, pd_info, pd_info).
:- mode pd_info_set_local_term_info(in, pd_info_di, pd_info_uo) is det.

:- pred pd_info_set_pred_info(pred_info, pd_info, pd_info).
:- mode pd_info_set_pred_info(in, pd_info_di, pd_info_uo) is det.

:- pred pd_info_set_parents(set(pred_proc_id), pd_info, pd_info).
:- mode pd_info_set_parents(in, pd_info_di, pd_info_uo) is det.

:- pred pd_info_set_pred_proc_id(pred_proc_id, pd_info, pd_info).
:- mode pd_info_set_pred_proc_id(in, pd_info_di, pd_info_uo) is det.

:- pred pd_info_set_changed(bool, pd_info, pd_info).
:- mode pd_info_set_changed(in, pd_info_di, pd_info_uo) is det.

:- pred pd_info_set_size_delta(int, pd_info, pd_info).
:- mode pd_info_set_size_delta(in, pd_info_di, pd_info_uo) is det.

	% Add the given amount to the cost delta.
:- pred pd_info_incr_cost_delta(int, pd_info, pd_info).
:- mode pd_info_incr_cost_delta(in, pd_info_di, pd_info_uo) is det.

	% Add the given amount to the size delta.
:- pred pd_info_incr_size_delta(int, pd_info, pd_info).
:- mode pd_info_incr_size_delta(in, pd_info_di, pd_info_uo) is det.

%-----------------------------------------------------------------------------%

:- implementation.

% Getters for the fields of the current unfold_info.  Each fetches the
% unfold_info (calling error/1 if none is set) and then picks out one
% field.
pd_info_get_proc_info(ProcInfo, PdInfo0, PdInfo) :-
	pd_info_get_unfold_info(UnfoldInfo, PdInfo0, PdInfo),
	UnfoldInfo = unfold_info(ProcInfo, _, _, _, _, _, _, _, _, _).
pd_info_get_instmap(InstMap, PdInfo0, PdInfo) :-
	pd_info_get_unfold_info(UnfoldInfo, PdInfo0, PdInfo),
	UnfoldInfo = unfold_info(_, InstMap, _, _, _, _, _, _, _, _).
pd_info_get_cost_delta(CostDelta, PdInfo0, PdInfo) :-
	pd_info_get_unfold_info(UnfoldInfo, PdInfo0, PdInfo),
	UnfoldInfo = unfold_info(_, _, CostDelta, _, _, _, _, _, _, _).
pd_info_get_local_term_info(TermInfo, PdInfo0, PdInfo) :-
	pd_info_get_unfold_info(UnfoldInfo, PdInfo0, PdInfo),
	UnfoldInfo = unfold_info(_, _, _, TermInfo, _, _, _, _, _, _).
pd_info_get_pred_info(PredInfo, PdInfo0, PdInfo) :-
	pd_info_get_unfold_info(UnfoldInfo, PdInfo0, PdInfo),
	UnfoldInfo = unfold_info(_, _, _, _, PredInfo, _, _, _, _, _).
pd_info_get_parents(Parents, PdInfo0, PdInfo) :-
	pd_info_get_unfold_info(UnfoldInfo, PdInfo0, PdInfo),
	UnfoldInfo = unfold_info(_, _, _, _, _, Parents, _, _, _, _).
pd_info_get_pred_proc_id(PredProcId, PdInfo0, PdInfo) :-
	pd_info_get_unfold_info(UnfoldInfo, PdInfo0, PdInfo),
	UnfoldInfo = unfold_info(_, _, _, _, _, _, PredProcId, _, _, _).
pd_info_get_changed(Changed, PdInfo0, PdInfo) :-
	pd_info_get_unfold_info(UnfoldInfo, PdInfo0, PdInfo),
	UnfoldInfo = unfold_info(_, _, _, _, _, _, _, Changed, _, _).
pd_info_get_size_delta(SizeDelta, PdInfo0, PdInfo) :-
	pd_info_get_unfold_info(UnfoldInfo, PdInfo0, PdInfo),
	UnfoldInfo = unfold_info(_, _, _, _, _, _, _, _, SizeDelta, _).
	

% Setters for the fields of the current unfold_info.  Each fetches the
% unfold_info (calling error/1 if none is set), rebuilds it with one
% field replaced, and stores it back.
pd_info_set_proc_info(ProcInfo, PdInfo0, PdInfo) :-
	pd_info_get_unfold_info(UnfoldInfo0, PdInfo0, PdInfo1),
	UnfoldInfo0 = unfold_info(_, B, C, D, E, F, G, H, I, J),
	UnfoldInfo = unfold_info(ProcInfo, B, C, D, E, F, G, H, I, J),
	pd_info_set_unfold_info(UnfoldInfo, PdInfo1, PdInfo).
pd_info_set_instmap(InstMap, PdInfo0, PdInfo) :-
	pd_info_get_unfold_info(UnfoldInfo0, PdInfo0, PdInfo1),
	UnfoldInfo0 = unfold_info(A, _, C, D, E, F, G, H, I, J),
	UnfoldInfo = unfold_info(A, InstMap, C, D, E, F, G, H, I, J),
	pd_info_set_unfold_info(UnfoldInfo, PdInfo1, PdInfo).
pd_info_set_cost_delta(CostDelta, PdInfo0, PdInfo) :-
	pd_info_get_unfold_info(UnfoldInfo0, PdInfo0, PdInfo1),
	UnfoldInfo0 = unfold_info(A, B, _, D, E, F, G, H, I, J),
	UnfoldInfo = unfold_info(A, B, CostDelta, D, E, F, G, H, I, J),
	pd_info_set_unfold_info(UnfoldInfo, PdInfo1, PdInfo).
pd_info_set_local_term_info(TermInfo, PdInfo0, PdInfo) :-
	pd_info_get_unfold_info(UnfoldInfo0, PdInfo0, PdInfo1),
	UnfoldInfo0 = unfold_info(A, B, C, _, E, F, G, H, I, J),
	UnfoldInfo = unfold_info(A, B, C, TermInfo, E, F, G, H, I, J),
	pd_info_set_unfold_info(UnfoldInfo, PdInfo1, PdInfo).
pd_info_set_pred_info(PredInfo, PdInfo0, PdInfo) :-
	pd_info_get_unfold_info(UnfoldInfo0, PdInfo0, PdInfo1),
	UnfoldInfo0 = unfold_info(A, B, C, D, _, F, G, H, I, J),
	UnfoldInfo = unfold_info(A, B, C, D, PredInfo, F, G, H, I, J),
	pd_info_set_unfold_info(UnfoldInfo, PdInfo1, PdInfo).
pd_info_set_parents(Parents, PdInfo0, PdInfo) :-
	pd_info_get_unfold_info(UnfoldInfo0, PdInfo0, PdInfo1),
	UnfoldInfo0 = unfold_info(A, B, C, D, E, _, G, H, I, J),
	UnfoldInfo = unfold_info(A, B, C, D, E, Parents, G, H, I, J),
	pd_info_set_unfold_info(UnfoldInfo, PdInfo1, PdInfo).
pd_info_set_pred_proc_id(PredProcId, PdInfo0, PdInfo) :-
	pd_info_get_unfold_info(UnfoldInfo0, PdInfo0, PdInfo1),
	UnfoldInfo0 = unfold_info(A, B, C, D, E, F, _, H, I, J),
	UnfoldInfo = unfold_info(A, B, C, D, E, F, PredProcId, H, I, J),
	pd_info_set_unfold_info(UnfoldInfo, PdInfo1, PdInfo).
pd_info_set_changed(Changed, PdInfo0, PdInfo) :-
	pd_info_get_unfold_info(UnfoldInfo0, PdInfo0, PdInfo1),
	UnfoldInfo0 = unfold_info(A, B, C, D, E, F, G, _, I, J),
	UnfoldInfo = unfold_info(A, B, C, D, E, F, G, Changed, I, J),
	pd_info_set_unfold_info(UnfoldInfo, PdInfo1, PdInfo).
pd_info_set_size_delta(SizeDelta, PdInfo0, PdInfo) :-
	pd_info_get_unfold_info(UnfoldInfo0, PdInfo0, PdInfo1),
	UnfoldInfo0 = unfold_info(A, B, C, D, E, F, G, H, _, J),
	UnfoldInfo = unfold_info(A, B, C, D, E, F, G, H, SizeDelta, J),
	pd_info_set_unfold_info(UnfoldInfo, PdInfo1, PdInfo).

% Add Increment to the cost delta of the current unfold_info.
pd_info_incr_cost_delta(Increment) -->
	pd_info_get_cost_delta(OldDelta),
	{ NewDelta = OldDelta + Increment },
	pd_info_set_cost_delta(NewDelta).

% Add Increment to the size delta of the current unfold_info.
pd_info_incr_size_delta(Increment) -->
	pd_info_get_size_delta(OldDelta),
	{ NewDelta = OldDelta + Increment },
	pd_info_set_size_delta(NewDelta).

%-----------------------------------------------------------------------------%
%-----------------------------------------------------------------------------%

:- interface.

% Look for an existing version that can be used for Goal in the
% current instmap.  Candidates are taken from the goal version index
% entry for the procedures Goal calls.  Returns no_version if there
% is no such entry or no candidate matches.
:- pred pd_info__search_version(hlds_goal::in, maybe_version::out,
	pd_info::pd_info_di, pd_info::pd_info_uo) is det.

	% Create a new predicate, named _mercury_deforestation__N with N
	% taken from the version counter, whose arguments are the
	% nonlocals of Goal.  Returns its pred_proc_id and the goal
	% produced by hlds_pred__define_new_pred (presumably a call to
	% the new predicate).
:- pred pd_info__define_new_pred(hlds_goal::in, pred_proc_id::out,
	hlds_goal::out, pd_info::pd_info_di, pd_info::pd_info_uo) is det.

	% Add a version to the table.  This also indexes it in the goal
	% version index and adds it to the set of created versions.
:- pred pd_info__register_version(pred_proc_id::in, version_info::in,
	pd_info::pd_info_di, pd_info::pd_info_uo) is det.

	% Remove a version and make sure it is never recreated.
	% The first and last calls of its goal are recorded as a
	% useless pair.
:- pred pd_info__invalidate_version(pred_proc_id::in,
	pd_info::pd_info_di, pd_info::pd_info_uo) is det.

	% Remove a version, but allow it to be recreated if it 
	% is used elsewhere.  Its predicate is also removed from the
	% module_info.
:- pred pd_info__remove_version(pred_proc_id::in,
	pd_info::pd_info_di, pd_info::pd_info_uo) is det.

	% The result of looking up a specialised version of a pred.
:- type maybe_version
	--->	no_version
	;	version(
			version_is_exact, 
			pred_proc_id, 
			version_info,
			map(var, var),	% renaming from the version's
					% variables to those of the goal
					% being matched
			tsubst		% var types substitution
		).

	% `exact' if the insts of the version and the new goal match in
	% both directions; `more_general' if the version's insts are
	% only at least as general (see pd_info__check_insts).
:- type version_is_exact
	--->	exact
	;	more_general.

:- type version_info
	---> version_info(
		hlds_goal,		% goal before unfolding.
		list(pred_proc_id),	% calls being deforested.
		list(var),		% arguments.
		list(type),		% argument types.
		instmap,		% initial insts of the nonlocals.
		int,			% cost of the original goal.
		int,			% improvement in cost.
		set(pred_proc_id), 	% parent versions.
		maybe(pred_proc_id)	% the version which was generalised
					% to produce this version.
	).

%-----------------------------------------------------------------------------%

:- implementation.

% Find the candidate versions indexed under the procedures Goal calls,
% then ask pd_info__get_matching_version for the best one.  If the index
% has no entry, or matching fails, the answer is no_version.
pd_info__search_version(Goal, MaybeVersion) -->
	pd_debug__output_goal("Searching for version:\n", Goal),
	{ pd_util__goal_get_calls(Goal, CalledPreds) },
	pd_info_get_module_info(ModuleInfo),
	pd_info_get_versions(Versions),
	pd_info_get_goal_version_index(GoalVersionIndex),
	pd_info_get_proc_info(ProcInfo),
	pd_info_get_instmap(InstMap),
	{ proc_info_vartypes(ProcInfo, VarTypes) },
	{
		map__search(GoalVersionIndex, CalledPreds, CandidateIds),
		pd_info__get_matching_version(ModuleInfo, Goal, InstMap,
			VarTypes, CandidateIds, Versions, FoundVersion)
	->
		MaybeVersion = FoundVersion
	;
		MaybeVersion = no_version
	},
	pd_debug__search_version_result(MaybeVersion).

%-----------------------------------------------------------------------------%

:- pred pd_info__get_matching_version(module_info::in, hlds_goal::in,
		instmap::in, map(var, type)::in, list(pred_proc_id)::in, 
		version_index::in, maybe_version::out) is semidet.

pd_info__get_matching_version(_, _, _, _, [], _, no_version).
pd_info__get_matching_version(ModuleInfo, ThisGoal, ThisInstMap, VarTypes, 
		[VersionId | VersionIds], Versions, MaybeVersion) :-
	map__lookup(Versions, VersionId, Version),
	Version = version_info(OldGoal, _, OldArgs, OldArgTypes,
			OldInstMap, _, _, _, _),
	(
		pd_info__goal_is_more_general(ModuleInfo,
			OldGoal, OldInstMap, OldArgs, OldArgTypes, 
			ThisGoal, ThisInstMap, VarTypes, VersionId, Version, 
			MaybeVersion1)
	->
		(
			MaybeVersion1 = no_version,
			pd_info__get_matching_version(ModuleInfo,
				ThisGoal, ThisInstMap, VarTypes, VersionIds,
				Versions, MaybeVersion)
		;
			MaybeVersion1 = version(exact, _, _, _, _),
			MaybeVersion = MaybeVersion1
		;
			MaybeVersion1 =
				version(more_general, PredProcId,
					MoreGeneralVersion, Renaming, TypeSubn),
			pd_info__get_matching_version(ModuleInfo, ThisGoal, 
				ThisInstMap, VarTypes, VersionIds, 
				Versions, MaybeVersion2),
			pd_info__pick_version(ModuleInfo, PredProcId, Renaming,
				TypeSubn, MoreGeneralVersion, MaybeVersion2,
				MaybeVersion)
		)
	;
		pd_info__get_matching_version(ModuleInfo, ThisGoal,
			ThisInstMap, VarTypes, VersionIds,
			Versions, MaybeVersion)
	).

%-----------------------------------------------------------------------------%

	% Choose between a more general version for the current candidate
	% (given by the second to fifth arguments) and the best version
	% found among the remaining candidates.  An exact version always
	% wins.  Between two more general versions, the one with the
	% larger improvement in cost wins; on a tie the second is kept.
:- pred pd_info__pick_version(module_info::in, pred_proc_id::in, 
		map(var, var)::in, tsubst::in, version_info::in, 
		maybe_version::in, maybe_version::out) is det.

pd_info__pick_version(_ModuleInfo, PredProcId1, Renaming1, TSubn1, Version1,
		OtherVersion, MaybeVersion) :-
	(
		OtherVersion = no_version,
		MaybeVersion = version(more_general, PredProcId1,
				Version1, Renaming1, TSubn1)
	;
		OtherVersion = version(exact, _, _, _, _),
		MaybeVersion = OtherVersion
	;
		OtherVersion = version(more_general, _, Version2, _, _),
		Version1 = version_info(_, _, _, _, _, _, CostDelta1, _, _),
		Version2 = version_info(_, _, _, _, _, _, CostDelta2, _, _),
		( CostDelta1 > CostDelta2 ->
			MaybeVersion = version(more_general, PredProcId1,
					Version1, Renaming1, TSubn1)
		;
			MaybeVersion = OtherVersion
		)
	).

%-----------------------------------------------------------------------------%

	% The aim of this is to check whether the first (old) goal can be
	% used instead of the second (new) goal if specialisation on the
	% new goal does not produce any more improvement.
	%
	% An old version is more general than a new one if:
	% - the goals have the same "shape" (see pd_util__goals_match).
	% - each variable in the old goal maps to exactly one
	% 	variable in the new (multiple vars in the new goal can
	% 	map to one var in the old).
	% - each nonlocal in the new goal maps to a non-local in the
	% 	old (i.e. the old version produces all the variables
	% 	that the new one does).
	% - for each pair of corresponding insts in the above mapping,
	%	the old inst must be at least as general as the new one,
	%	i.e. inst_matches_initial(NewInst, OldInst) must hold.
	%	pd_info__check_insts performs this check on the old
	%	goal's nonlocals.
	%
	% On success the result is always version(Exact, ...).  Exact is
	% `exact' if the insts also match in the other direction, and
	% `more_general' otherwise.
	%
:- pred pd_info__goal_is_more_general(module_info::in, hlds_goal::in, 
	instmap::in, list(var)::in, list(type)::in, hlds_goal::in, instmap::in,
	map(var, type)::in, pred_proc_id::in, 
	version_info::in, maybe_version::out) is semidet.

pd_info__goal_is_more_general(ModuleInfo, OldGoal, OldInstMap, OldArgs, 
		OldArgTypes, NewGoal, NewInstMap, NewVarTypes, PredProcId, 
		Version, MaybeVersion) :-
	pd_util__goals_match(ModuleInfo, OldGoal, OldArgs, OldArgTypes, 
		NewGoal, NewVarTypes, OldNewRenaming, TypeRenaming), 
	% Compare the insts of the old goal's nonlocals with those of
	% the variables they are renamed to in the new goal.
	OldGoal = _ - OldGoalInfo,
	goal_info_get_nonlocals(OldGoalInfo, OldNonLocals),
	set__to_sorted_list(OldNonLocals, OldNonLocalsList),
	pd_info__check_insts(ModuleInfo, OldNonLocalsList, OldNewRenaming, 
		OldInstMap, NewInstMap, exact, Exact),
	MaybeVersion = version(Exact, PredProcId, Version, 
		OldNewRenaming, TypeRenaming).

%-----------------------------------------------------------------------------%

	% Check that every inst in the old version is at least as general
	% as the inst of the corresponding variable in the new version,
	% and track whether every pair also matches the other way round
	% (`exact').  Fails as soon as one pair does not match.
:- pred pd_info__check_insts(module_info::in, list(var)::in, map(var, var)::in,
		instmap::in, instmap::in, version_is_exact::in, 
		version_is_exact::out) is semidet.

pd_info__check_insts(_, [], _, _, _, Exact, Exact).
pd_info__check_insts(ModuleInfo, [OldVar | OldVars], VarRenaming, OldInstMap,
		NewInstMap, Exact0, Exact) :-
	instmap__lookup_var(OldInstMap, OldVar, OldInst),
	map__lookup(VarRenaming, OldVar, NewVar),
	instmap__lookup_var(NewInstMap, NewVar, NewInst),
	% Required: the old inst is at least as general as the new one.
	inst_matches_initial(NewInst, OldInst, ModuleInfo),
	(
		Exact0 = more_general,
		Exact1 = more_general
	;
		Exact0 = exact,
		% Does inst_matches_initial(Inst1, Inst2, M) and
		% inst_matches_initial(Inst2, Inst1, M) imply that Inst1
		% and Inst2 are interchangeable?
		( inst_matches_initial(OldInst, NewInst, ModuleInfo) ->
			Exact1 = exact
		;
			Exact1 = more_general
		)
	),
	pd_info__check_insts(ModuleInfo, OldVars, VarRenaming, OldInstMap,
		NewInstMap, Exact1, Exact).

%-----------------------------------------------------------------------------%

% Create a new predicate whose body is Goal and whose arguments are
% Goal's nonlocals (in sorted order).  Each new predicate is named
% from the version counter, which is bumped here.
pd_info__define_new_pred(Goal, PredProcId, CallGoal) -->
	pd_info_get_instmap(InstMap),
	{ Goal = _ - GoalInfo },
	{ goal_info_get_nonlocals(GoalInfo, NonLocals) },
	{ set__to_sorted_list(NonLocals, Args) },

	% Take the next number for the predicate's name.
	pd_info_get_counter(SeqNum),
	{ NextSeqNum = SeqNum + 1 },
	pd_info_set_counter(NextSeqNum),
	{ string__format("_mercury_deforestation__%i", [i(SeqNum)], Name) },

	% Pass the type, class, marker and variable information of the
	% procedure being unfolded on to hlds_pred__define_new_pred.
	pd_info_get_proc_info(ProcInfo),
	pd_info_get_pred_info(PredInfo),
	{ pred_info_typevarset(PredInfo, TVarSet) },
	{ pred_info_get_markers(PredInfo, Markers) },
	{ pred_info_get_class_context(PredInfo, ClassContext) },
	{ proc_info_varset(ProcInfo, VarSet) },
	{ proc_info_vartypes(ProcInfo, VarTypes) },
	{ proc_info_typeinfo_varmap(ProcInfo, TVarMap) },
	{ proc_info_typeclass_info_varmap(ProcInfo, TCVarMap) },
	pd_info_get_module_info(ModuleInfo0),
	{ hlds_pred__define_new_pred(Goal, CallGoal, Args, InstMap,
		Name, TVarSet, VarTypes, ClassContext, TVarMap, TCVarMap,
		VarSet, Markers, ModuleInfo0, ModuleInfo, PredProcId) },
	pd_info_set_module_info(ModuleInfo).

%-----------------------------------------------------------------------------%

% Record a new version.  It is put at the front of the goal version
% index entry for the procedures its goal calls, stored in the version
% table, and added to the set of created versions.
pd_info__register_version(PredProcId, Version) -->
	pd_debug__register_version(PredProcId, Version),

	{ Version = version_info(Goal, _, _, _, _, _, _, _, _) },
	{ pd_util__goal_get_calls(Goal, Calls) },
	pd_info_get_goal_version_index(GoalVersionIndex0),
	{ map__search(GoalVersionIndex0, Calls, OlderVersions) ->
		IndexedVersions = [PredProcId | OlderVersions]
	;
		IndexedVersions = [PredProcId]
	},
	{ map__set(GoalVersionIndex0, Calls, IndexedVersions,
		GoalVersionIndex) },
	pd_info_set_goal_version_index(GoalVersionIndex),

	pd_info_get_versions(Versions0),
	{ map__det_insert(Versions0, PredProcId, Version, Versions) },
	pd_info_set_versions(Versions),

	pd_info_get_created_versions(Created0),
	{ set__insert(Created0, PredProcId, Created) },
	pd_info_set_created_versions(Created).

%-----------------------------------------------------------------------------%

% Remove a version for good.  If its goal makes any calls, the pair of
% its first and last calls is recorded as useless before the version
% itself is removed.
pd_info__invalidate_version(PredProcId) -->
	pd_info_get_versions(Versions),
	{ map__lookup(Versions, PredProcId, VersionInfo) },
	{ VersionInfo = version_info(Goal, _, _, _, _, _, _, _, _) },
	{ pd_util__goal_get_calls(Goal, Calls) },
	(
		{ Calls = [FirstCall | _] },
		{ list__last(Calls, LastCall) }
	->
		% Make sure we never create another version to
		% deforest this pair of calls.
		pd_info_get_useless_versions(UselessPairs0),
		{ set__insert(UselessPairs0, FirstCall - LastCall,
			UselessPairs) },
		pd_info_set_useless_versions(UselessPairs)
	;
		[]
	),
	pd_info__remove_version(PredProcId).

% Remove every trace of a version: its version table entry, its place in
% the goal version index, its membership of the created versions, and
% its predicate in the module_info.
pd_info__remove_version(PredProcId) -->
	% Drop it from the version table.
	pd_info_get_versions(Versions0),
	{ map__lookup(Versions0, PredProcId, VersionInfo) },
	{ map__delete(Versions0, PredProcId, Versions) },
	pd_info_set_versions(Versions),

	% Drop it from the goal version index entry (if any) for the
	% procedures its goal calls.
	{ VersionInfo = version_info(Goal, _, _, _, _, _, _, _, _) },
	{ pd_util__goal_get_calls(Goal, Calls) },
	pd_info_get_goal_version_index(GoalIndex0),
	(
		{ map__search(GoalIndex0, Calls, IndexedVersions0) }
	->
		{ list__delete_all(IndexedVersions0, PredProcId,
			IndexedVersions) },
		{ map__det_update(GoalIndex0, Calls, IndexedVersions,
			GoalIndex) },
		pd_info_set_goal_version_index(GoalIndex)
	;
		[]
	),

	% It no longer counts as a created version.
	pd_info_get_created_versions(Created0),
	{ set__delete(Created0, PredProcId, Created) },
	pd_info_set_created_versions(Created),

	% Finally remove its predicate from the module.
	pd_info_get_module_info(ModuleInfo0),
	{ PredProcId = proc(PredId, _) },
	{ module_info_remove_predicate(PredId, ModuleInfo0, ModuleInfo) },
	pd_info_set_module_info(ModuleInfo).

%-----------------------------------------------------------------------------%
%-----------------------------------------------------------------------------%


%-----------------------------------------------------------------------------%
% Copyright (C) 1998 University of Melbourne.
% This file may only be copied under the terms of the GNU General
% Public License - see the file COPYING in the Mercury distribution.
%-----------------------------------------------------------------------------%
% File: pd_term.m
% Main author: stayl
%
% Termination checking for partial deduction / deforestation.
%
%
% For conjunctions, count up the length of the conjunction.
% For each pair of calls on the end of the conjunction,
% this length must decrease for the check to succeed. 
%
% For single calls, the first call records the sizes of the insts
% of all the arguments. If the total size of a later call increases,
% the increasing arguments are removed from the record. If there are 
% no decreasing arguments, the termination check fails. Otherwise
% the check succeeds and the new argument sizes are recorded.
%
% There are many possible improvements to this:
% - Partition on subterms of arguments rather than whole arguments - useful
% 	when partially instantiated structures are present.
% - Use homeomorphic embedding instead of term sizes as suggested in
% 	the papers on partial deduction from K.U. Leuven. This will be 
% 	useful (necessary?) if we start propagating equality constraints.
%
%-----------------------------------------------------------------------------%
:- module pd_term.

:- interface.

:- import_module hlds_goal, hlds_module, instmap, pd_info.

:- pred pd_term__global_check(module_info::in, hlds_goal::in,
		list(hlds_goal)::in, hlds_goal::in,
		instmap::in, version_index::in, 
		global_term_info::in, global_term_info::out, 
		global_check_result::out) is det.
	% pd_term__global_check(ModuleInfo, EarlierGoal, BetweenGoals,
	%	LaterGoal, InstMap, Versions, Info0, Info, Result):
	%
	% Global termination check for deforesting a conjunction that
	% starts with the call EarlierGoal and ends with the call
	% LaterGoal, with BetweenGoals in between.
	% Calls to versions in Versions are first expanded back to
	% procedures of the original program.
	% ModuleInfo and InstMap are currently unused.
	% Aborts if EarlierGoal or LaterGoal is not a call.

:- type global_check_result
	--->	ok(pair(pred_proc_id), int)
			% The pair of end procedures and the number of
			% goals between them. Presumably these are what
			% is later passed to
			% pd_term__update_global_term_info.
	;	possible_loop(pair(pred_proc_id), int, pred_proc_id)
			% The same pair has been seen with the same number
			% of goals between, and a version was recorded for
			% it (the last argument).
	;	loop.

	% pd_term__local_check(ModuleInfo, Goal, InstMap, Cover0, Cover):
	%
	% Termination check used while unfolding the call Goal. The first
	% call to a procedure records the sizes of its argument insts.
	% A later call succeeds only if those sizes have decreased, as
	% checked by pd_term__do_local_check.
	% Fails if the check fails or if Goal is not a call.
:- pred pd_term__local_check(module_info::in, hlds_goal::in,
	instmap::in, local_term_info::in, local_term_info::out) is semidet.

	% Create global termination information with nothing recorded.
:- pred pd_term__global_term_info_init(global_term_info::out) is det.

	% Create local termination information with nothing recorded.
:- pred pd_term__local_term_info_init(local_term_info::out) is det.

	% Look up the argument sizes recorded for a procedure.
	% Fails if nothing has been recorded for it.
:- pred pd_term__get_proc_term_info(local_term_info::in, pred_proc_id::in,
		pd_proc_term_info::out) is semidet.

	% Update the global termination information when we find
	% out the pred_proc_id that has been assigned to a version.
:- pred pd_term__update_global_term_info(global_term_info::in,
		pair(pred_proc_id)::in, pred_proc_id::in, int::in,
		global_term_info::out) is det.

:- type global_term_info. 
:- type local_term_info.
:- type pd_proc_term_info.

%-----------------------------------------------------------------------------%
:- implementation.

:- import_module hlds_pred, (inst), mode_util, prog_data, pd_util.
:- import_module assoc_list, bool, int, list, map, require, set, std_util, term.

:- type global_term_info
	--->	global_term_info(
			single_covering_goals,
				% Passed through unchanged by
				% pd_term__global_check.
			multiple_covering_goals
				% Records for the pairs of procedures at
				% the ends of deforested conjunctions.
		).

	% We only deal with single atoms while unfolding.
:- type local_term_info == single_covering_goals.

	% For single goals, use the argument partition method.
:- type single_covering_goals == map(pred_proc_id, pd_proc_term_info).

	% Map from the pair of procedures at the ends of a conjunction
	% to be deforested to the number of goals between them in the
	% most recent such conjunction. The map also holds the version
	% created for that conjunction, if one has been recorded yet.
:- type multiple_covering_goals == map(pair(pred_proc_id), 
					pair(int, maybe(pred_proc_id))).

		% Mapping from argument number to the size of its inst.
:- type pd_proc_term_info	== 	assoc_list(int, int).
	
%-----------------------------------------------------------------------------%

pd_term__global_term_info_init(global_term_info(SingleGoals, MultipleGoals)) :-
	% Nothing has been recorded yet, either for single calls or
	% for pairs of calls.
	map__init(MultipleGoals),
	map__init(SingleGoals).

pd_term__local_term_info_init(Cover) :-
	% No argument sizes have been recorded for any procedure yet.
	map__init(EmptyCover),
	Cover = EmptyCover.

pd_term__get_proc_term_info(Cover, CalledProc, RecordedSizes) :-
	% Fails if no sizes have been recorded for CalledProc.
	map__search(Cover, CalledProc, RecordedSizes).

%-----------------------------------------------------------------------------%
:- import_module string.

pd_term__global_check(_ModuleInfo, EarlierGoal, BetweenGoals, LaterGoal, 
		_InstMap, Versions, Info0, Info, Result) :-
	% Only the record for pairs of calls is used here; the record
	% for single calls is passed through unchanged.
	Info0 = global_term_info(SingleGoalCover, MultipleGoalCover0),
	(
		EarlierGoal = call(EarlierPredId, EarlierProcId,
			_, _, _, _) - _,
		LaterGoal = call(LaterPredId, LaterProcId, _, _, _, _) - _,
		GetFirst = lambda([Calls::in, FirstCall::out] is semidet, 
			Calls = [FirstCall | _]),
		% Expand calls to versions back to procedures of the
		% original program (see expand_calls).
		expand_calls(GetFirst, Versions,
			proc(EarlierPredId, EarlierProcId), FirstPredProcId),
		expand_calls(list__last, Versions,
			proc(LaterPredId, LaterProcId), LastPredProcId)
	->
		list__length(BetweenGoals, Length),
		pd_term__check_proc_pair(FirstPredProcId - LastPredProcId,
			Length, MultipleGoalCover0, MultipleGoalCover, Result)
	;
		error("pd_term__global_check")
	),
	Info = global_term_info(SingleGoalCover, MultipleGoalCover).

:- pred pd_term__check_proc_pair(pair(pred_proc_id)::in, int::in,
		multiple_covering_goals::in, multiple_covering_goals::out,
		global_check_result::out) is det.

	% Classify the pair of end procedures ProcPair of a conjunction
	% with Length goals between them, and update the record of pairs:
	%  - a pair not seen before, or last seen with more goals between,
	%    is `ok'. Length is recorded with no version; the version
	%    is filled in when the new predicate is created.
	%  - a pair last seen with the same number of goals, for which a
	%    version has been recorded, is a `possible_loop'. The goals
	%    may still need to be checked or generalised.
	%  - anything else is a `loop'.
pd_term__check_proc_pair(ProcPair, Length, Cover0, Cover, Result) :-
	( map__search(Cover0, ProcPair, MaxLength - MaybeCovering) ->
		( Length < MaxLength ->
			Result = ok(ProcPair, Length),
			map__set(Cover0, ProcPair, Length - no, Cover)
		; Length = MaxLength, MaybeCovering = yes(CoveringPredProcId) ->
			Result = possible_loop(ProcPair, Length,
					CoveringPredProcId),
			Cover = Cover0
		;
			Result = loop,
			Cover = Cover0
		)
	;
		Result = ok(ProcPair, Length),
		map__set(Cover0, ProcPair, Length - no, Cover)
	).

	% Folded calls to parent versions must not be used in the global
	% termination check, since they could produce an unbounded
	% sequence such as
	%	old ....pred1
	% 	new1 .... pred1
	% 	new2 ....... pred1
	% 	new3 ......... pred1
	% So while PredProcId0 is one of the versions in Versions, replace
	% it with the call that GetEnd selects from the list of calls in
	% its version_info. The result is then a procedure of the original
	% program, and there are only finitely many pairs of those.
	% Fails if GetEnd fails.
:- pred expand_calls(pred(list(pred_proc_id), pred_proc_id), version_index,
		pred_proc_id, pred_proc_id).
:- mode expand_calls(pred(in, out) is semidet, in, in, out) is semidet.

expand_calls(GetEnd, Versions, PredProcId0, PredProcId) :-
	(
		map__search(Versions, PredProcId0, VersionInfo)
	->
		VersionInfo = version_info(_, VersionCalls, _, _, _, _, _, _, _),
		call(GetEnd, VersionCalls, EndPredProcId),
		expand_calls(GetEnd, Versions, EndPredProcId, PredProcId)
	;
		% Not a version, so nothing to expand.
		PredProcId = PredProcId0	
	).

%-----------------------------------------------------------------------------%

pd_term__local_check(ModuleInfo, Goal1, InstMap, Cover0, Cover) :-
	% Only calls can be checked; any other goal makes the check fail.
	Goal1 = call(PredId, ProcId, Args, _, _, _) - _,
	CalledProc = proc(PredId, ProcId),
	( map__search(Cover0, CalledProc, RecordedSizes) ->
		% Seen before: the argument sizes must have decreased.
		pd_term__do_local_check(ModuleInfo, InstMap, Args,
			RecordedSizes, Sizes)
	;
		% First call to this procedure: just record the sizes.
		pd_term__initial_sizes(ModuleInfo, InstMap, Args, 1, Sizes)
	),
	map__set(Cover0, CalledProc, Sizes, Cover).

:- pred pd_term__do_local_check(module_info::in, instmap::in, 
		list(var)::in, assoc_list(int, int)::in, 
		assoc_list(int, int)::out) is semidet.

	% Compare the current argument sizes of a call with the recorded
	% ones, returning the new record to keep. Fails if the sizes have
	% not decreased.
pd_term__do_local_check(ModuleInfo, InstMap, Args, OldSizes, NewSizes) :-
	pd_term__get_matching_sizes(ModuleInfo, InstMap, Args, 
		OldSizes, CurrentSizes, OldTotal, NewTotal),
	( NewTotal >= OldTotal ->
		% No overall decrease: keep only the arguments that did not
		% grow, and require split_out_non_increasing to report a
		% decrease among them.
		pd_term__split_out_non_increasing(OldSizes, CurrentSizes,
			FoundDecreasing, NewSizes),
		FoundDecreasing = yes
	;
		% The total decreased, so record all of the current sizes.
		NewSizes = CurrentSizes
	).

%-----------------------------------------------------------------------------%

pd_term__update_global_term_info(global_term_info(Single, Multiple0), ProcPair, 
		PredProcId, Size, global_term_info(Single, Multiple)) :-
	% Store the version alongside the size, so that a later
	% conjunction with the same end procedures and the same size
	% is reported by pd_term__global_check as a possible loop
	% covered by this version.
	map__set(Multiple0, ProcPair, Size - yes(PredProcId), Multiple).

%-----------------------------------------------------------------------------%

:- pred pd_term__initial_sizes(module_info::in, instmap::in, list(var)::in,
		int::in, assoc_list(int, int)::out) is det.

	% Pair each argument number (counting from the given start) with
	% the size of that argument's inst in InstMap.
pd_term__initial_sizes(_, _, [], _, []).
pd_term__initial_sizes(ModuleInfo, InstMap, [Arg | Args], ArgNo, Sizes) :-
	NextArgNo is ArgNo + 1,
	pd_term__initial_sizes(ModuleInfo, InstMap, Args, NextArgNo,
		LaterSizes),
	instmap__lookup_var(InstMap, Arg, ArgInst),
	pd_util__inst_size(ModuleInfo, ArgInst, ArgSize),
	Sizes = [ArgNo - ArgSize | LaterSizes].

%-----------------------------------------------------------------------------%

:- pred pd_term__get_matching_sizes(module_info::in, instmap::in, 
		list(var)::in, assoc_list(int, int)::in, 
		assoc_list(int, int)::out, int::out, int::out) is det.

	% For each argument number in the recorded sizes, look up the
	% current size of that argument's inst. Also return the totals
	% of the recorded sizes and of the current sizes.
pd_term__get_matching_sizes(_, _, _, [], [], 0, 0).
pd_term__get_matching_sizes(ModuleInfo, InstMap, Args, 
		[ArgNo - OldSize | OldSizes], NewSizes, OldTotal, NewTotal) :-
	pd_term__get_matching_sizes(ModuleInfo, InstMap, Args,
		OldSizes, LaterNewSizes, LaterOldTotal, LaterNewTotal),
	list__index1_det(Args, ArgNo, Arg),
	instmap__lookup_var(InstMap, Arg, ArgInst),
	pd_util__inst_size(ModuleInfo, ArgInst, NewSize),
	NewSizes = [ArgNo - NewSize | LaterNewSizes],
	OldTotal is LaterOldTotal + OldSize,
	NewTotal is LaterNewTotal + NewSize.
	
%-----------------------------------------------------------------------------%

:- pred pd_term__split_out_non_increasing(assoc_list(int, int)::in,
		assoc_list(int, int)::in, bool::out,
		assoc_list(int, int)::out) is semidet.

	% pd_term__split_out_non_increasing(OldSizes, NewSizes,
	%		FoundDecreasing, NonIncreasing):
	%
	% OldSizes and NewSizes are parallel lists of argument numbers
	% and inst sizes, for the recorded call and the current call.
	% NonIncreasing contains the entries of NewSizes whose size has
	% not grown; arguments whose size increased are dropped.
	% FoundDecreasing is `yes' if at least one argument strictly
	% decreased in size, and `no' otherwise.
	% Aborts if the two lists have different lengths.
pd_term__split_out_non_increasing([], [], no, []).
pd_term__split_out_non_increasing([_|_], [], _, _) :-
	error("pd_term__split_out_non_increasing").
pd_term__split_out_non_increasing([], [_|_], _, _) :-
	error("pd_term__split_out_non_increasing").
pd_term__split_out_non_increasing([Arg - OldSize | Args0], 
		[_ - NewSize | Args], FoundDecreasing, NonIncreasing) :-
	pd_term__split_out_non_increasing(Args0, Args,
		FoundDecreasing1, NonIncreasing1),
	( NewSize =< OldSize ->
		NonIncreasing = [Arg - NewSize | NonIncreasing1],
		( NewSize < OldSize ->
			FoundDecreasing = yes
		;
			% An unchanged argument must not hide a strict
			% decrease in one of the later arguments.
			FoundDecreasing = FoundDecreasing1
		)
	;
		NonIncreasing = NonIncreasing1,
		FoundDecreasing = FoundDecreasing1
	).

%-----------------------------------------------------------------------------%
%-----------------------------------------------------------------------------%



%-----------------------------------------------------------------------------%
% Copyright (C) 1998 University of Melbourne.
% This file may only be copied under the terms of the GNU General
% Public License - see the file COPYING in the Mercury distribution.
%-----------------------------------------------------------------------------%
% File: pd_util.m
% Main author: stayl.
%
% Utility predicates for deforestation and partial evaluation.
%
%-----------------------------------------------------------------------------%
:- module pd_util.

:- interface.

:- import_module pd_info, hlds_goal, hlds_module, hlds_pred, mode_errors.
:- import_module simplify.
:- import_module list, set.

	% Pick out the pred_proc_ids of the calls in a list of atomic goals.
:- pred pd_util__goal_get_calls(hlds_goal::in,
		list(pred_proc_id)::out) is det.

:- pred pd_util__simplify_goal(simplify::in, hlds_goal::in, hlds_goal::out,
		pd_info::pd_info_di, pd_info::pd_info_uo) is det.
	% Run simplify__process_goal with the given simplify settings on
	% a goal of the procedure currently being processed, starting
	% from the current instmap. The procedure's varset and vartypes
	% and the module_info in the pd_info are updated. The cost delta
	% reported by simplification is added to the pd_info.

:- pred pd_util__unique_modecheck_goal(hlds_goal::in, hlds_goal::out,
		list(mode_error_info)::out, pd_info::pd_info_di, 
		pd_info::pd_info_uo) is det.
	% Run unique mode analysis on a goal of the procedure currently
	% being processed, returning any mode errors found. This version
	% treats as live the non-locals of the goal that the goal does
	% not clobber.

:- pred pd_util__unique_modecheck_goal(set(var)::in, hlds_goal::in, 
		hlds_goal::out, list(mode_error_info)::out, 
		pd_info::pd_info_di, pd_info::pd_info_uo) is det.
	% As above, but with the set of live variables given explicitly.
	% If the debug_pd option is set, the mode errors are also
	% reported.

	% Find out which arguments of the procedure are interesting
	% for deforestation.
:- pred pd_util__get_branch_vars_proc(pred_proc_id::in, proc_info::in, 
		pd_arg_info::in, pd_arg_info::out,
		module_info::in, module_info::out) is det.

	% Find out which variables of the goal are interesting
	% for deforestation.
:- pred pd_util__get_branch_vars_goal(hlds_goal::in, 
		maybe(pd_branch_info(var))::out, pd_info::pd_info_di,
		pd_info::pd_info_uo) is det.

:- pred pd_util__requantify_goal(hlds_goal::in, set(var)::in, hlds_goal::out,
		pd_info::pd_info_di, pd_info::pd_info_uo) is det.
	% Run implicit quantification on a goal, with the given set as
	% its non-locals. The varset and vartypes of the current proc_info
	% are updated.

:- pred pd_util__recompute_instmap_delta(hlds_goal::in, hlds_goal::out, 
		pd_info::pd_info_di, pd_info::pd_info_uo) is det.
	% Presumably recomputes the instmap_deltas of a goal starting from
	% the current instmap. TODO confirm against the definition.

	% Convert from information about the argument positions to 
	% information about the argument variables.
:- pred pd_util__convert_branch_info(pd_branch_info(int)::in, list(var)::in,
		pd_branch_info(var)::out) is det.	

	% inst_msg(InstA, InstB, InstC):
	% 	Take the most specific generalisation of two insts.
	%       The information in InstC is the minimum of the
	%       information in InstA and InstB.  Where InstA and
	%       InstB specify a binding (free or bound), it must be
	%       the same in both.
	% 	The uniqueness of the final inst is taken from InstB.
	% 	The difference between inst_merge and inst_msg is that the 
	% 	msg of `bound([functor, []])' and `bound([another_functor, []])'
	%	is `ground' rather than `bound([functor, another_functor])'. 
	% 	Also the msgs are not tabled, so the module_info is not
	% 	threaded through.
	% 	If an inst is "rounded off", it must not contain `any' insts
	% 	and must be completely unique or completely non-unique.
	% 	This is used in generalisation to avoid non-termination
	% 	of deforestation: InstA is the inst in the old version
	% 	with which we are taking the msg, and InstB is the inst
	% 	in the new version we want to create.
	%	It is always safe for inst_msg to fail - this will just
	% 	result in less optimization.
	% 	Mode analysis should be run on the goal to
	%	check that this doesn't introduce mode errors, since
	% 	the information that was removed may actually have been
	%	necessary for mode correctness.
:- pred inst_msg(inst, inst, module_info, inst).
:- mode inst_msg(in, in, in, out) is semidet.

:- pred pd_util__inst_list_size(module_info::in, list(inst)::in,
		set(inst_name)::in, int::in, int::out) is det.
	% NOTE(review): this presumably adds the sizes of the insts in the
	% list to an accumulator. The set(inst_name) looks like the set of
	% inst names already being expanded, to stop infinite expansion of
	% recursive insts. Confirm against the definition.

:- pred pd_util__inst_size(module_info::in, (inst)::in, int::out) is det.
	% Compute a size for an inst. pd_term uses this as the measure of
	% argument size in its local termination check.

	% pd_util__goals_match(ModuleInfo, OldGoal, OldArgs, OldArgTypes,
	% 		NewGoal, NewVarTypes, OldToNewRenaming, TypeSubn)
	%
	% Check the shape of the goals, and return a mapping from
	% variables in the old goal to variables in the new and
	% a substitution to apply to the types. This only
	% attempts to match `simple' lists of goals, which contain
	% only conj, some, not and atomic goals, since deforest.m
	% only attempts to optimize those types of conjunctions.
:- pred pd_util__goals_match(module_info::in, hlds_goal::in, list(var)::in,
		list(type)::in, hlds_goal::in, map(var, type)::in,
		map(var, var)::out, substitution::out) is semidet.

	% pd_util__can_reorder_goals(ModuleInfo, FullyStrict, Goal1, Goal2).
	%
	% Goals can be reordered if
	% - the goals are independent
	% - the goals are pure
	% - any possible change in termination behaviour is allowed
	% 	according to the semantics options.
:- pred pd_util__can_reorder_goals(module_info::in, bool::in, hlds_goal::in,
		hlds_goal::in) is semidet.

	% pd_util__reordering_maintains_termination(ModuleInfo, FullyStrict,
	%		Goal1, Goal2)
	%
	% Succeeds if any possible change in termination behaviour from
	% reordering the goals is allowed according to the semantics options.
:- pred pd_util__reordering_maintains_termination(module_info::in, bool::in, 
		hlds_goal::in, hlds_goal::in) is semidet.

%-----------------------------------------------------------------------------%
:- implementation.

:- import_module pd_cost, hlds_data, instmap, mode_util, prog_data.
:- import_module unused_args, inst_match, (inst), quantification, mode_util.
:- import_module code_aux, purity, mode_info, unique_modes.
:- import_module type_util, det_util, options.
:- import_module assoc_list, bool, int, list, map.
:- import_module require, set, std_util, term.

pd_util__goal_get_calls(Goal0, CalledPreds) :-
	% Keep the pred_proc_id of each call among the conjuncts of
	% Goal0, in order, and ignore all other conjuncts.
	IsCall = lambda([Conjunct::in, PredProcId::out] is semidet, (
			Conjunct = call(CalledPredId, CalledProcId,
				_, _, _, _) - _,
			PredProcId = proc(CalledPredId, CalledProcId)
		)),
	goal_to_conj_list(Goal0, Conjuncts),
	list__filter_map(IsCall, Conjuncts, CalledPreds).

%-----------------------------------------------------------------------------%

pd_util__simplify_goal(Simplify, Goal0, Goal) -->
	%
	% Build a simplify_info from the current state of the pd_info.
	% 
	pd_info_get_module_info(ModuleInfo0),
	{ module_info_globals(ModuleInfo0, Globals) },
	pd_info_get_pred_proc_id(proc(PredId, ProcId)),
	{ det_info_init(ModuleInfo0, PredId, ProcId, Globals, DetInfo) },
	pd_info_get_instmap(InstMap),
	pd_info_get_proc_info(ProcInfoBefore),
	{
		proc_info_varset(ProcInfoBefore, VarSetBefore),
		proc_info_vartypes(ProcInfoBefore, VarTypesBefore),
		simplify_info_init(DetInfo, Simplify, InstMap,
			VarSetBefore, VarTypesBefore, SimplifyInfo0),

		simplify__process_goal(Goal0, Goal,
			SimplifyInfo0, SimplifyInfo),

		simplify_info_get_module_info(SimplifyInfo, ModuleInfo),
		simplify_info_get_varset(SimplifyInfo, VarSet),
		simplify_info_get_var_types(SimplifyInfo, VarTypes),
		simplify_info_get_cost_delta(SimplifyInfo, CostDelta)
	},

	%
	% Copy the results of simplification back into the pd_info.
	%
	pd_info_get_proc_info(ProcInfo1),
	{ proc_info_set_varset(ProcInfo1, VarSet, ProcInfo2) },
	{ proc_info_set_vartypes(ProcInfo2, VarTypes, ProcInfo) },
	pd_info_set_proc_info(ProcInfo),
	pd_info_incr_cost_delta(CostDelta),
	pd_info_set_module_info(ModuleInfo).

%-----------------------------------------------------------------------------%

pd_util__unique_modecheck_goal(Goal0, Goal, Errors) -->
	% No live variables were supplied, so treat as live the goal's
	% non-locals that it does not clobber.
	pd_util__get_goal_live_vars(Goal0, GoalLiveVars),
	pd_util__unique_modecheck_goal(GoalLiveVars, Goal0, Goal, Errors).

pd_util__unique_modecheck_goal(LiveVars, Goal0, Goal, Errors) -->

	% 
	% Construct a mode_info for the current procedure, starting from
	% the current instmap. The module_info is first updated with the
	% pd_info's pred_info and proc_info, so the mode_info is built
	% from those. The io__state is moved out of the pd_info into the
	% mode_info and is put back at the end.
	%
	pd_info_get_pred_proc_id(PredProcId),
	{ PredProcId = proc(PredId, ProcId) },
	pd_info_get_module_info(ModuleInfo0),
	pd_info_get_instmap(InstMap0),
	{ term__context_init(Context) },
	pd_info_get_io_state(IO0),
	pd_info_get_pred_info(PredInfo0),
	pd_info_get_proc_info(ProcInfo0),
	{ module_info_set_pred_proc_info(ModuleInfo0, PredId, ProcId,
		PredInfo0, ProcInfo0, ModuleInfo1) },
	{ mode_info_init(IO0, ModuleInfo1, PredId, ProcId, Context,
		LiveVars, InstMap0, ModeInfo0) },

	{ unique_modes__check_goal(Goal0, Goal, ModeInfo0, ModeInfo1) },
	% If the debug_pd option is set, report the mode errors found.
	pd_info_lookup_bool_option(debug_pd, Debug),
	{ Debug = yes ->
		report_mode_errors(ModeInfo1, ModeInfo)
	;
		ModeInfo = ModeInfo1
	},
	{ mode_info_get_errors(ModeInfo, Errors) },

	%
	% Copy the module_info, pred_info, varset and vartypes back into
	% the pd_info, and restore the io__state.
	%
	{ mode_info_get_module_info(ModeInfo, ModuleInfo) },
	{ mode_info_get_io_state(ModeInfo, IO) },
	{ mode_info_get_varset(ModeInfo, VarSet) },
	{ mode_info_get_var_types(ModeInfo, VarTypes) },
	pd_info_set_module_info(ModuleInfo),
	{ module_info_pred_proc_info(ModuleInfo, PredId, ProcId,
		PredInfo, ProcInfo1) },
	pd_info_set_pred_info(PredInfo),
	{ proc_info_set_varset(ProcInfo1, VarSet, ProcInfo2) },
	{ proc_info_set_vartypes(ProcInfo2, VarTypes, ProcInfo) },
	pd_info_set_proc_info(ProcInfo),
	pd_info_set_io_state(IO).

	% Compute the variables that are still live after the goal: the
	% goal's non-locals whose final inst is not clobbered. The final
	% inst is taken from the goal's instmap_delta, or from the current
	% instmap if the delta does not mention the variable.
:- pred pd_util__get_goal_live_vars(hlds_goal::in, set(var)::out, 
		pd_info::pd_info_di, pd_info::pd_info_uo) is det.

pd_util__get_goal_live_vars(_ - GoalInfo, LiveVars) -->
	pd_info_get_module_info(ModuleInfo),
	pd_info_get_instmap(InstMap),
	{
		goal_info_get_nonlocals(GoalInfo, NonLocals),
		goal_info_get_instmap_delta(GoalInfo, InstMapDelta),
		set__to_sorted_list(NonLocals, NonLocalsList),
		set__init(NoVars),
		pd_util__get_goal_live_vars_2(ModuleInfo, NonLocalsList,
			InstMap, InstMapDelta, NoVars, LiveVars)
	}.

:- pred pd_util__get_goal_live_vars_2(module_info::in, list(var)::in,
	instmap::in, instmap_delta::in, set(var)::in, set(var)::out) is det.

	% Add to the set each variable in the list whose final inst is not
	% clobbered. The final inst comes from the instmap_delta if the
	% variable is there, and otherwise from the instmap.
pd_util__get_goal_live_vars_2(ModuleInfo, NonLocals, InstMap, InstMapDelta,
		Vars0, Vars) :-
	AddIfNotClobbered =
	    lambda([NonLocal::in, LiveVars0::in, LiveVars::out] is det, (
		( instmap_delta_search_var(InstMapDelta, NonLocal, DeltaInst) ->
			FinalInst = DeltaInst
		;
			instmap__lookup_var(InstMap, NonLocal, FinalInst)
		),
		( inst_is_clobbered(ModuleInfo, FinalInst) ->
			LiveVars = LiveVars0
		;
			set__insert(LiveVars0, NonLocal, LiveVars)
		)
	    )),
	list__foldl(AddIfNotClobbered, NonLocals, Vars0, Vars).

%-----------------------------------------------------------------------------%

pd_util__convert_branch_info(ArgInfo, Args, VarInfo) :-
	ArgInfo = pd_branch_info(ArgMap, LeftArgs, OpaqueArgs),
	map__to_assoc_list(ArgMap, ArgBranches),
	map__init(EmptyBranchVarMap),
	pd_util__convert_branch_info_2(ArgBranches, Args,
		EmptyBranchVarMap, BranchVarMap),
	pd_util__arg_nos_to_vars(Args, LeftArgs, LeftVars),
	pd_util__arg_nos_to_vars(Args, OpaqueArgs, OpaqueVars),
	VarInfo = pd_branch_info(BranchVarMap, LeftVars, OpaqueVars).

	% Map a set of argument positions (numbered from 1) to the set of
	% argument variables at those positions.
:- pred pd_util__arg_nos_to_vars(list(var)::in, set(int)::in,
		set(var)::out) is det.

pd_util__arg_nos_to_vars(Args, ArgNos, ArgVars) :-
	set__to_sorted_list(ArgNos, ArgNoList),
	list__map(list__index1_det(Args), ArgNoList, ArgVarList),
	set__list_to_set(ArgVarList, ArgVars).

:- pred pd_util__convert_branch_info_2(assoc_list(int, set(int))::in, 
		list(var)::in, pd_var_info::in, pd_var_info::out) is det.

	% Re-key the per-argument branch sets by the argument variables.
pd_util__convert_branch_info_2(ArgBranchList, Args, Info0, Info) :-
	AddArgBranches =
	    lambda([ArgBranches::in, VarMap0::in, VarMap::out] is det, (
		ArgBranches = ArgNo - Branches,
		list__index1_det(Args, ArgNo, Arg),
		map__set(VarMap0, Arg, Branches, VarMap)
	    )),
	list__foldl(AddArgBranches, ArgBranchList, Info0, Info).

%-----------------------------------------------------------------------------%

:- type pd_var_info 	==	branch_info_map(var).
	% Branch information keyed by variable. pd_util__convert_branch_info
	% builds it from the branch_info_map(int) form, which is keyed by
	% argument number.

pd_util__get_branch_vars_proc(PredProcId, ProcInfo, 
		Info0, Info, ModuleInfo0, ModuleInfo) :-
	% Analyse the top-level conjunction of the procedure body. This
	% succeeds only if the conjunction is made of atomic goals and
	% exactly one branched goal (see get_branch_vars_goal_2).
	% Otherwise nothing is recorded for the procedure.
	proc_info_goal(ProcInfo, Goal),
	instmap__init_reachable(InstMap0),
	map__init(Vars0),
	set__init(LeftVars0),
	goal_to_conj_list(Goal, GoalList),
	(
		pd_util__get_branch_vars_goal_2(ModuleInfo0, GoalList, no, 
			InstMap0, LeftVars0, LeftVars, Vars0, Vars)
	->
		proc_info_headvars(ProcInfo, HeadVars),
		map__init(ThisProcArgMap0),
		set__init(ThisProcLeftArgs0),
		% Translate the per-variable information into per-argument
		% information for this procedure.
		pd_util__get_extra_info_headvars(HeadVars, 1, LeftVars, Vars, 
			ThisProcArgMap0, ThisProcArgMap1, 
			ThisProcLeftArgs0, ThisProcLeftArgs),
		set__init(OpaqueArgs0),
		% Record a preliminary entry for this procedure before the
		% sub-branches are examined. Info1 is what the sub-branch
		% analysis looks calls up in (presumably so that recursive
		% calls see this entry).
		BranchInfo0 = pd_branch_info(ThisProcArgMap1, 
				ThisProcLeftArgs, OpaqueArgs0),
		map__set(Info0, PredProcId, BranchInfo0, Info1),

			% Look for opportunities for deforestation in 
			% the sub-branches of the top-level goal.
		pd_util__get_sub_branch_vars_goal(ModuleInfo0, Info1,
			GoalList, InstMap0, Vars, AllVars, ModuleInfo),
		% Recompute the argument map including the sub-branch
		% information. The left arguments computed by this call are
		% discarded; those from the first pass are kept.
		pd_util__get_extra_info_headvars(HeadVars, 1, LeftVars0,
			AllVars, ThisProcArgMap0, ThisProcArgMap, 
			ThisProcLeftArgs0, _),

		% Output arguments with no extra information are opaque.
		proc_info_argmodes(ProcInfo, ArgModes),
		pd_util__get_opaque_args(ModuleInfo, 1, ArgModes, 
			ThisProcArgMap, OpaqueArgs0, OpaqueArgs),

		BranchInfo = pd_branch_info(ThisProcArgMap, ThisProcLeftArgs,
				OpaqueArgs),
		map__set(Info1, PredProcId, BranchInfo, Info)
	;
		ModuleInfo = ModuleInfo0,
		Info = Info0
	).

	% Collect the numbers of the output arguments for which no branch
	% gives extra information (io__states are a typical example).
	% If a later goal in a conjunction depends on one of these, the
	% deforestation is unlikely to be able to fold to give a
	% recursive definition.
:- pred pd_util__get_opaque_args(module_info::in, int::in, list(mode)::in,
		branch_info_map(int)::in, set(int)::in, set(int)::out) is det.

pd_util__get_opaque_args(_, _, [], _, OpaqueArgs, OpaqueArgs).
pd_util__get_opaque_args(ModuleInfo, ArgNo, [ArgMode | ArgModes],
		ExtraInfoArgs, OpaqueArgs0, OpaqueArgs) :-
	( \+ mode_is_output(ModuleInfo, ArgMode) ->
		% Only output arguments can be opaque.
		OpaqueArgs1 = OpaqueArgs0
	; map__contains(ExtraInfoArgs, ArgNo) ->
		% Some branch tells us something about this output.
		OpaqueArgs1 = OpaqueArgs0
	;
		set__insert(OpaqueArgs0, ArgNo, OpaqueArgs1)
	),
	NextArgNo is ArgNo + 1,
	pd_util__get_opaque_args(ModuleInfo, NextArgNo, ArgModes,
		ExtraInfoArgs, OpaqueArgs1, OpaqueArgs).

	% Translate per-variable branch information into per-argument
	% information. Argument ArgNo (head variables are numbered from
	% the given start) gets the branch set of its head variable, if
	% VarInfo has one. It is added to the left arguments if its head
	% variable is in LeftVars.
:- pred pd_util__get_extra_info_headvars(list(var)::in, int::in,
		set(var)::in, pd_var_info::in, 
		branch_info_map(int)::in, branch_info_map(int)::out, 
		set(int)::in, set(int)::out) is det.

pd_util__get_extra_info_headvars([], _, _, _, Args, Args, LeftArgs, LeftArgs).
pd_util__get_extra_info_headvars([HeadVar | HeadVars], ArgNo, 
		LeftVars, VarInfo, ArgMap0, ArgMap, LeftArgs0, LeftArgs) :-
	( set__member(HeadVar, LeftVars) ->
		set__insert(LeftArgs0, ArgNo, LeftArgs1)
	;
		LeftArgs1 = LeftArgs0
	),
	( map__search(VarInfo, HeadVar, Branches) ->
		map__det_insert(ArgMap0, ArgNo, Branches, ArgMap1)
	;
		ArgMap1 = ArgMap0
	),
	NextArgNo is ArgNo + 1,
	pd_util__get_extra_info_headvars(HeadVars, NextArgNo,
		LeftVars, VarInfo, ArgMap1, ArgMap, LeftArgs1, LeftArgs).

%-----------------------------------------------------------------------------%

pd_util__get_branch_vars_goal(Goal, MaybeBranchInfo) -->
	pd_info_get_module_info(ModuleInfo0),
	pd_info_get_instmap(InstMap0),
	pd_info_get_proc_arg_info(ProcArgInfo),
	{ set__init(NoLeftVars) },
	{ map__init(NoBranchVars) },
	(
		% The goal must be a single branched goal (or an atomic
		% goal after a branched one, which cannot happen here).
		{ pd_util__get_branch_vars_goal_2(ModuleInfo0, [Goal], no, 
			InstMap0, NoLeftVars, LeftVars, NoBranchVars, TopVars) }
	->
		{ pd_util__get_sub_branch_vars_goal(ModuleInfo0, ProcArgInfo, 
			[Goal], InstMap0, TopVars, AllVars, ModuleInfo) },
		pd_info_set_module_info(ModuleInfo),

			% Opaque variables are only recorded for calls.
		{ set__init(OpaqueVars) },
		{ BranchInfo = pd_branch_info(AllVars, LeftVars, OpaqueVars) },
		{ MaybeBranchInfo = yes(BranchInfo) }
	;
		{ MaybeBranchInfo = no }
	).

:- pred pd_util__get_branch_vars_goal_2(module_info::in, list(hlds_goal)::in, 
	bool::in, instmap::in, set(var)::in, set(var)::out,
	pd_var_info::in, pd_var_info::out) is semidet.

	% pd_util__get_branch_vars_goal_2(ModuleInfo, Goals, FoundBranch,
	%	InstMap, LeftVars0, LeftVars, Vars0, Vars):
	%
	% Succeeds only if Goals, together with the goals already seen,
	% contains exactly one branched goal (if-then-else, switch or
	% disjunction) and otherwise only atomic goals. FoundBranch says
	% whether the branched goal has already been seen.
	% For the branched goal, the per-branch variable information is
	% added to Vars and its switched-on variable, if any, to LeftVars.
	% InstMap is advanced past each goal using its instmap_delta.
pd_util__get_branch_vars_goal_2(_, [], yes, _, LeftVars, LeftVars, Vars, Vars).
pd_util__get_branch_vars_goal_2(ModuleInfo, [Goal | Goals], FoundBranch0,
		InstMap0, LeftVars0, LeftVars, Vars0, Vars) :-
	Goal = _ - GoalInfo,
	goal_info_get_instmap_delta(GoalInfo, InstMapDelta),
	instmap__apply_instmap_delta(InstMap0, InstMapDelta, InstMap),
	( pd_util__get_branch_instmap_deltas(Goal, InstMapDeltas) ->
		% Only look for goals with one top-level branched goal,
		% since deforestation of goals with more than one is
		% likely to be less productive.
		FoundBranch0 = no,
		% InstMap is the instmap after Goal, i.e. after its
		% branches have been merged.
		pd_util__get_branch_vars(ModuleInfo, Goal, InstMapDeltas, 
			InstMap, 1, Vars0, Vars1),
		pd_util__get_left_vars(Goal, LeftVars0, LeftVars1),
		FoundBranch = yes
	;
		% Apart from the single branched goal, only atomic
		% goals are allowed.
		Goal = GoalExpr - _,
		goal_is_atomic(GoalExpr),
		FoundBranch = FoundBranch0,
		Vars1 = Vars0,
		LeftVars1 = LeftVars0
	),
	pd_util__get_branch_vars_goal_2(ModuleInfo, Goals, FoundBranch, 
		InstMap, LeftVars1, LeftVars, Vars1, Vars).

:- pred pd_util__get_branch_instmap_deltas(hlds_goal::in, 
		list(instmap_delta)::out) is semidet.

	% Return the instmap_delta of each branch of a branched goal, in
	% order. For an if-then-else these are the condition, the
	% then-branch and the else-branch. Fails for goals that are not
	% if-then-elses, switches or disjunctions.
pd_util__get_branch_instmap_deltas(GoalExpr - _, InstMapDeltas) :-
	(
		GoalExpr = if_then_else(_, _ - CondInfo, _ - ThenInfo,
			_ - ElseInfo, _),
		goal_info_get_instmap_delta(CondInfo, CondDelta),
		goal_info_get_instmap_delta(ThenInfo, ThenDelta),
		goal_info_get_instmap_delta(ElseInfo, ElseDelta),
		InstMapDeltas = [CondDelta, ThenDelta, ElseDelta]
	;
		GoalExpr = switch(_, _, Cases, _),
		GetCaseDelta = lambda([Case::in, CaseDelta::out] is det, (
				Case = case(_, _ - CaseInfo),
				goal_info_get_instmap_delta(CaseInfo, CaseDelta)
			)),
		list__map(GetCaseDelta, Cases, InstMapDeltas)
	;
		GoalExpr = disj(Disjuncts, _),
		GetDisjunctDelta =
			lambda([Disjunct::in, DisjunctDelta::out] is det, (
				Disjunct = _ - DisjunctInfo,
				goal_info_get_instmap_delta(DisjunctInfo,
					DisjunctDelta)
			)),
		list__map(GetDisjunctDelta, Disjuncts, InstMapDeltas)
	).


	% Collect the variables for which unfolding is possible if the
	% goals to the left supply the top-level functor. At present this
	% is only the variable switched on by a switch. Eventually this
	% should also check for if-then-elses with simple conditions.
:- pred pd_util__get_left_vars(hlds_goal::in, 
		set(var)::in, set(var)::out) is det.

pd_util__get_left_vars(GoalExpr - _, Vars0, Vars) :-
	( GoalExpr = switch(SwitchVar, _, _, _) ->
		set__insert(Vars0, SwitchVar, Vars)
	;
		Vars = Vars0
	).

:- pred pd_util__get_branch_vars(module_info::in, hlds_goal::in, 
		list(instmap_delta)::in, instmap::in, int::in, 
		pd_var_info::in, pd_var_info::out) is det.

	% pd_util__get_branch_vars(ModuleInfo, Goal, BranchDeltas, InstMap,
	%	BranchNo, ExtraVars0, ExtraVars):
	%
	% Number the branch instmap_deltas consecutively from BranchNo.
	% For each branch, add its number to the set recorded in ExtraVars
	% for every variable that meets both of these conditions:
	%  - the branch's delta binds it to exactly one functor;
	%  - its inst in InstMap is not bound to exactly one functor.
	% If Goal is a switch, the switched-on variable is recorded for
	% every branch.
	% This predicate cannot fail, so it is declared det.
pd_util__get_branch_vars(_, _, [], _, _, Extra, Extra).
pd_util__get_branch_vars(ModuleInfo, Goal, [InstMapDelta | InstMapDeltas], 
		InstMap, BranchNo, ExtraVars0, ExtraVars) :-
	AddExtraInfoVars = 
	    lambda([ChangedVar::in, Vars0::in, Vars::out] is det, (
		(
			instmap__lookup_var(InstMap, ChangedVar, VarInst),
			instmap_delta_search_var(InstMapDelta, 
				ChangedVar, DeltaVarInst),
			% The branch knows the top-level functor...
		    	inst_is_bound_to_functors(ModuleInfo, 
				DeltaVarInst, [_]),
			% ...but InstMap does not.
		    	\+ inst_is_bound_to_functors(ModuleInfo, 
				VarInst, [_])
	    	->
			( map__search(Vars0, ChangedVar, Set0) ->
				set__insert(Set0, BranchNo, Set)
			;
				set__singleton_set(Set, BranchNo)
			),
			map__set(Vars0, ChangedVar, Set, Vars)
		;
			Vars = Vars0
		)
	    )),
	instmap_delta_changed_vars(InstMapDelta, ChangedVars),
	set__to_sorted_list(ChangedVars, ChangedVarsList),
	list__foldl(AddExtraInfoVars, ChangedVarsList, ExtraVars0, ExtraVars1),

		% We have extra information about a switched-on variable 
		% at the end of each branch.
	( Goal = switch(SwitchVar, _, _, _) - _ ->
		( map__search(ExtraVars1, SwitchVar, SwitchVarSet0) ->
			set__insert(SwitchVarSet0, BranchNo, SwitchVarSet)
		;
			set__singleton_set(SwitchVarSet, BranchNo)
		),
		map__set(ExtraVars1, SwitchVar, SwitchVarSet, ExtraVars2)
	;
		ExtraVars2 = ExtraVars1
	),
	NextBranch is BranchNo + 1,
	pd_util__get_branch_vars(ModuleInfo, Goal, InstMapDeltas, InstMap, 
		NextBranch, ExtraVars2, ExtraVars).

	% Look inside the branches of each branched goal in the list. Each
	% branch is searched for calls to procedures that have recorded
	% branch information in ProcArgInfo, and for nested branched
	% goals. The variables found are added to the map, tagged with
	% the number of the branch they occur in. The module_info is
	% threaded through because instmap__bind_var_to_functor, used for
	% switch cases, returns an updated module_info.
	%
	% NOTE(review): here the then-branch of an if-then-else is
	% numbered 1 and the else-branch 2. get_branch_instmap_deltas
	% returns [Cond, Then, Else] deltas, which get_branch_vars numbers
	% 1, 2 and 3. Confirm that these two numberings are meant to
	% differ.
:- pred pd_util__get_sub_branch_vars_goal(module_info::in, pd_arg_info::in,
		list(hlds_goal)::in, instmap::in, branch_info_map(var)::in, 
		branch_info_map(var)::out, module_info::out) is det.

pd_util__get_sub_branch_vars_goal(Module, _, [], _, Vars, Vars, Module).
pd_util__get_sub_branch_vars_goal(ModuleInfo0, ProcArgInfo, [Goal | GoalList], 
		InstMap0, Vars0, SubVars, ModuleInfo) :-
	Goal = GoalExpr - GoalInfo,
	( GoalExpr = if_then_else(_, Cond, Then, Else, _) ->
		% The then-branch starts from the instmap after the
		% condition; the else-branch from the instmap before it.
		Cond = _ - CondInfo,
		goal_info_get_instmap_delta(CondInfo, CondDelta),
		instmap__apply_instmap_delta(InstMap0, CondDelta, InstMap1),
		goal_to_conj_list(Then, ThenList),
		pd_util__examine_branch(ModuleInfo0, ProcArgInfo, 1, ThenList,
			InstMap1, Vars0, Vars1),
		goal_to_conj_list(Else, ElseList),
		pd_util__examine_branch(ModuleInfo0, ProcArgInfo, 2, ElseList,
			InstMap0, Vars1, Vars2),
		ModuleInfo1 = ModuleInfo0
	; GoalExpr = disj(Goals, _) ->
		pd_util__examine_branch_list(ModuleInfo0, ProcArgInfo, 
			1, Goals, InstMap0, Vars0, Vars2),
		ModuleInfo1 = ModuleInfo0
	; GoalExpr = switch(Var, _, Cases, _) ->
		pd_util__examine_case_list(ModuleInfo0, ProcArgInfo, 1, Var,
			Cases, InstMap0, Vars0, Vars2, ModuleInfo1)
	;
		ModuleInfo1 = ModuleInfo0,
		Vars2 = Vars0
	),
	goal_info_get_instmap_delta(GoalInfo, InstMapDelta),
	instmap__apply_instmap_delta(InstMap0, InstMapDelta, InstMap),
	pd_util__get_sub_branch_vars_goal(ModuleInfo1, ProcArgInfo, GoalList,
		InstMap, Vars2, SubVars, ModuleInfo).

:- pred pd_util__examine_branch_list(module_info::in, pd_arg_info::in, int::in,
	list(hlds_goal)::in, instmap::in, branch_info_map(var)::in, 
	branch_info_map(var)::out) is det.

	% Examine each disjunct as a separate branch, numbering them
	% consecutively from BranchNo. Every disjunct starts from the
	% same instmap.
pd_util__examine_branch_list(_, _, _, [], _, Vars, Vars).
pd_util__examine_branch_list(ModuleInfo, ProcArgInfo, BranchNo,
		[Disjunct | Disjuncts], InstMap, Vars0, Vars) :-
	goal_to_conj_list(Disjunct, Conjuncts),
	pd_util__examine_branch(ModuleInfo, ProcArgInfo, BranchNo,
		Conjuncts, InstMap, Vars0, Vars1),
	NextBranchNo is BranchNo + 1,
	pd_util__examine_branch_list(ModuleInfo, ProcArgInfo, NextBranchNo,
		Disjuncts, InstMap, Vars1, Vars).

:- pred pd_util__examine_case_list(module_info::in, pd_arg_info::in, int::in,
	var::in, list(case)::in, instmap::in, branch_info_map(var)::in, 
	branch_info_map(var)::out, module_info::out) is det.

	% Examine each case of a switch on Var as a separate branch,
	% numbering them consecutively from BranchNo. Each case starts
	% from InstMap with Var bound to that case's functor.
pd_util__examine_case_list(Module, _, _, _, [], _, Vars, Vars, Module).
pd_util__examine_case_list(ModuleInfo0, ProcArgInfo, BranchNo, Var,
		[Case | Cases], InstMap, Vars0, Vars, ModuleInfo) :-
	Case = case(ConsId, CaseGoal),
	instmap__bind_var_to_functor(Var, ConsId, InstMap, CaseInstMap, 
		ModuleInfo0, ModuleInfo1),
	goal_to_conj_list(CaseGoal, CaseConjuncts),
	pd_util__examine_branch(ModuleInfo1, ProcArgInfo, BranchNo,
		CaseConjuncts, CaseInstMap, Vars0, Vars1),
	NextBranchNo is BranchNo + 1,
	pd_util__examine_case_list(ModuleInfo1, ProcArgInfo, NextBranchNo,
		Var, Cases, InstMap, Vars1, Vars, ModuleInfo).

:- pred pd_util__examine_branch(module_info::in, pd_arg_info::in, int::in,
		list(hlds_goal)::in, instmap::in, branch_info_map(var)::in,
		branch_info_map(var)::out) is det.

	% Walk the conjuncts of one branch, tagging with BranchNo every
	% variable about which a conjunct gives extra information:
	%  - for a call with recorded branch information, the arguments
	%    that information mentions;
	%  - for another goal accepted by get_branch_vars_goal_2, the
	%    variables it finds.
	% The instmap is advanced past each conjunct.
pd_util__examine_branch(_, _, _, [], _, Vars, Vars).
pd_util__examine_branch(ModuleInfo, ProcArgInfo, BranchNo, 
		[Goal | Goals], InstMap, Vars0, Vars) :-
	Goal = GoalExpr - GoalInfo,
	(
		GoalExpr = call(PredId, ProcId, Args, _, _, _)
	->
		(
			map__search(ProcArgInfo, proc(PredId, ProcId),
				CalleeArgInfo)
		->
			pd_util__convert_branch_info(CalleeArgInfo, Args,
				CalleeVarInfo),
			CalleeVarInfo = pd_branch_info(CalleeVarMap, _, _),
			map__keys(CalleeVarMap, CalleeExtraVars),
			combine_vars(Vars0, BranchNo, CalleeExtraVars, Vars1)
		;
			Vars1 = Vars0
		)
	;
		set__init(NoLeftVars),
		map__init(EmptyVarMap),
		pd_util__get_branch_vars_goal_2(ModuleInfo, [Goal], no,
			InstMap, NoLeftVars, _, EmptyVarMap, GoalVarMap)
	->
		map__keys(GoalVarMap, GoalExtraVars),
		combine_vars(Vars0, BranchNo, GoalExtraVars, Vars1)
	;
		Vars1 = Vars0
	),
	goal_info_get_instmap_delta(GoalInfo, InstMapDelta),
	instmap__apply_instmap_delta(InstMap, InstMapDelta, NextInstMap),
	pd_util__examine_branch(ModuleInfo, ProcArgInfo, BranchNo,
		Goals, NextInstMap, Vars1, Vars).

:- pred combine_vars(branch_info_map(var)::in, int::in, list(var)::in,
		branch_info_map(var)::out) is det.

% combine_vars(Vars0, BranchNo, ExtraVars, Vars):
%	For each variable in ExtraVars, add BranchNo to the set of
%	branch numbers recorded for it in Vars0. A variable not yet
%	present gets the singleton set {BranchNo}.

combine_vars(Vars0, BranchNo, ExtraVars, Vars) :-
	AddBranch = lambda([ExtraVar::in, Map0::in, Map::out] is det, (
		( map__search(Map0, ExtraVar, OldBranches) ->
			set__insert(OldBranches, BranchNo, NewBranches)
		;
			set__singleton_set(NewBranches, BranchNo)
		),
		map__set(Map0, ExtraVar, NewBranches, Map)
	)),
	list__foldl(AddBranch, ExtraVars, Vars0, Vars).

%-----------------------------------------------------------------------------%

% pd_util__requantify_goal(Goal0, NonLocals, Goal):
%	Rerun implicit quantification on Goal0 with the given set of
%	non-locals. Quantification reads the varset and vartypes of the
%	proc_info held in the pd_info, and the updated varset and
%	vartypes are stored back into it.

pd_util__requantify_goal(Goal0, NonLocals, Goal) -->
	pd_info_get_proc_info(OldProcInfo),
	{
		proc_info_varset(OldProcInfo, OldVarSet),
		proc_info_vartypes(OldProcInfo, OldVarTypes),
		implicitly_quantify_goal(Goal0, OldVarSet, OldVarTypes,
			NonLocals, Goal, NewVarSet, NewVarTypes, _),
		proc_info_set_varset(OldProcInfo, NewVarSet, MidProcInfo),
		proc_info_set_vartypes(MidProcInfo, NewVarTypes, NewProcInfo)
	},
	pd_info_set_proc_info(NewProcInfo).

% pd_util__recompute_instmap_delta(Goal0, Goal):
%	Recompute the instmap deltas in Goal0, starting from the
%	instmap currently held in the pd_info. The module_info may be
%	updated by the recomputation, so it is stored back afterwards.

pd_util__recompute_instmap_delta(Goal0, Goal) -->
	pd_info_get_module_info(OldModuleInfo),
	pd_info_get_instmap(StartInstMap),
	{ recompute_instmap_delta(yes, Goal0, Goal, StartInstMap,
		OldModuleInfo, NewModuleInfo) },
	pd_info_set_module_info(NewModuleInfo).

%-----------------------------------------------------------------------------%

	% inst_msg(InstA, InstB, ModuleInfo, Inst):
	%	The information in Inst is the minimum of the
	%	information in InstA and InstB. Where InstA and InstB
	%	specify a binding (free or bound), it must be the same
	%	in both. Bindings to different constructors are rounded
	%	off to ground. When in doubt, fail: this only results
	%	in less optimization, not loss of correctness.
	%	Identical insts are returned unchanged. Otherwise both
	%	are expanded with inst_expand; if the expanded InstB is
	%	not_reached the result is the expanded InstA, else the
	%	expanded insts are combined by inst_msg_2.

inst_msg(InstA, InstB, ModuleInfo, Inst) :-
	( InstA = InstB ->
		Inst = InstA
	;
		inst_expand(ModuleInfo, InstA, ExpandedA),
		inst_expand(ModuleInfo, InstB, ExpandedB),
		( ExpandedB = not_reached ->
			Inst = ExpandedA
		;
			inst_msg_2(ExpandedA, ExpandedB, ModuleInfo, Inst)
		)
	).

% inst_msg_2(InstA, InstB, ModuleInfo, Inst):
%	The msg of two insts that the caller (inst_msg) has already
%	expanded. Only the combinations listed below succeed. Every
%	other combination (e.g. free with bound, or free(_) with
%	anything) fails. Where the two insts carry different uniqueness
%	or ground inst info, the result takes InstB's.
:- pred inst_msg_2(inst, inst, module_info, inst).
:- mode inst_msg_2(in, in, in, out) is semidet.

inst_msg_2(any(_), any(Uniq), _, any(Uniq)).
inst_msg_2(free, free, _M, free).

	% Two bound insts: combine them functor by functor, possibly
	% rounding off to ground (see bound_inst_list_msg).
inst_msg_2(bound(_, ListA), bound(UniqB, ListB), ModuleInfo, Inst) :-
	bound_inst_list_msg(ListA, ListB, ModuleInfo, UniqB, ListB, Inst).
inst_msg_2(bound(_, _), ground(UniqB, InfoB), _, ground(UniqB, InfoB)).

	% fail here, since the increasing inst size could 
	% cause termination problems for deforestation.
inst_msg_2(ground(_, _), bound(_UniqB, _ListB), _, _) :- fail.
	% NOTE(review): this returns InstB's InfoB even when it differs
	% from InstA's; confirm that is intended for a generalisation.
inst_msg_2(ground(_, _), ground(UniqB, InfoB), _, ground(UniqB, InfoB)). 
inst_msg_2(abstract_inst(Name, ArgsA), abstract_inst(Name, ArgsB),
		ModuleInfo, abstract_inst(Name, Args)) :-
	inst_list_msg(ArgsA, ArgsB, ModuleInfo, Args).
	% A not_reached InstA yields InstB unchanged.
inst_msg_2(not_reached, Inst, _, Inst).

:- pred inst_list_msg(list(inst), list(inst), module_info, list(inst)).
:- mode inst_list_msg(in, in, in, out) is semidet.

% inst_list_msg(InstsA, InstsB, ModuleInfo, Insts):
%	Pairwise inst_msg of two lists of insts. Fails if the lists
%	differ in length or if inst_msg fails for any pair.

inst_list_msg([], [], _, []).
inst_list_msg([InstA | LaterA], [InstB | LaterB], ModuleInfo,
		[MsgInst | LaterMsgs]) :-
	inst_msg(InstA, InstB, ModuleInfo, MsgInst),
	inst_list_msg(LaterA, LaterB, ModuleInfo, LaterMsgs).

	% bound_inst_list_msg(Xs, Ys, ModuleInfo, Uniq, List, Inst):
	% As called from inst_msg_2, compute the msg of bound(_, Xs)
	% and bound(Uniq, Ys). The two input lists Xs and Ys must
	% already be sorted. Uniq is the uniqueness given to the result.
	% List is the complete bound_inst list of the second inst: the
	% caller passes ListB as both Ys and List, and List is passed
	% unchanged through the recursion.
	% If the lists have the same functors in the same order, the
	% result is bound(Uniq, ...), taking the msg of each pair of
	% argument insts; if any of those fails, this fails.
	% If any of the functors in Xs are not in Ys or vice versa,
	% the result is rounded off to ground(Uniq, no). This happens
	% only if it is OK to round off the uniqueness information:
	% either Uniq is shared and bound(shared, List) is ground and
	% not partly unique, or Uniq is unique and bound(unique, List)
	% is unique. Otherwise we fail, since the msg operation could
	% introduce mode errors. Once any suffix of the lists has been
	% rounded off to ground, the whole result is ground as well.
	% NOTE(review): the round-off checks only look at List (the
	% second inst's functors); the functors in Xs are not checked.

:- pred bound_inst_list_msg(list(bound_inst), list(bound_inst),
		module_info, uniqueness, list(bound_inst), inst).
:- mode bound_inst_list_msg(in, in, in, in, in, out) is semidet.

bound_inst_list_msg(Xs, Ys, ModuleInfo, Uniq, List, Inst) :-
	(
		Xs = [],
		Ys = []
	->
		Inst = bound(Uniq, [])
	;	
		Xs = [X | Xs1],
		Ys = [Y | Ys1],
		X = functor(ConsId, ArgsX),
		Y = functor(ConsId, ArgsY)
	->
		inst_list_msg(ArgsX, ArgsY, ModuleInfo, Args),
		Z = functor(ConsId, Args),
		bound_inst_list_msg(Xs1, Ys1, ModuleInfo, Uniq, List, Inst1),
		% If the rest was rounded off to ground, so is the whole.
		( Inst1 = bound(Uniq, Zs) ->
			Inst = bound(Uniq, [Z | Zs])
		;
			Inst = Inst1
		)
	;
		% Check that it's OK to round off the uniqueness information.
		( 
			Uniq = shared,
			inst_is_ground(ModuleInfo, bound(shared, List)),
			inst_is_not_partly_unique(ModuleInfo, 
				bound(shared, List))
		;
			Uniq = unique,
			inst_is_unique(ModuleInfo, bound(unique, List))
		),		
		Inst = ground(Uniq, no)
	).

%-----------------------------------------------------------------------------%

% pd_util__inst_size(ModuleInfo, Inst, Size):
%	Size is the size of Inst as measured by pd_util__inst_size_2,
%	starting with no inst names expanded.

pd_util__inst_size(ModuleInfo, Inst, Size) :-
	set__init(NoneExpanded),
	pd_util__inst_size_2(ModuleInfo, Inst, NoneExpanded, Size).

:- pred pd_util__inst_size_2(module_info::in, (inst)::in,
		set(inst_name)::in, int::out) is det.

% pd_util__inst_size_2(ModuleInfo, Inst, Seen, Size):
%	Seen holds the inst names already expanded on the current path.
%	A bound inst counts 1 plus, for each functor, 1 plus the sizes
%	of its argument insts. A defined inst counts as the inst it
%	names, except that a name already in Seen counts as 1; this
%	stops the expansion of recursive insts. The remaining insts
%	(not_reached, any, free, ground, inst_var, abstract_inst)
%	count as 0.

pd_util__inst_size_2(ModuleInfo, bound(_, BoundInsts), Seen, Size) :-
	pd_util__bound_inst_size(ModuleInfo, BoundInsts, Seen, 1, Size).
pd_util__inst_size_2(ModuleInfo, defined_inst(Name), Seen0, Size) :-
	( set__member(Name, Seen0) ->
		% Recursive reference: don't expand it again.
		Size = 1
	;
		set__insert(Seen0, Name, Seen),
		inst_lookup(ModuleInfo, Name, NamedInst),
		pd_util__inst_size_2(ModuleInfo, NamedInst, Seen, Size)
	).
pd_util__inst_size_2(_, not_reached, _, 0).
pd_util__inst_size_2(_, any(_), _, 0).
pd_util__inst_size_2(_, free, _, 0).
pd_util__inst_size_2(_, free(_), _, 0).
pd_util__inst_size_2(_, ground(_, _), _, 0).
pd_util__inst_size_2(_, inst_var(_), _, 0).
pd_util__inst_size_2(_, abstract_inst(_, _), _, 0).

:- pred pd_util__bound_inst_size(module_info::in, list(bound_inst)::in,
		set(inst_name)::in, int::in, int::out) is det.

% pd_util__bound_inst_size(ModuleInfo, BoundInsts, Seen, Size0, Size):
%	Size is Size0 plus, for each functor in BoundInsts, 1 plus the
%	total size of that functor's argument insts.

pd_util__bound_inst_size(_, [], _, Size, Size).
pd_util__bound_inst_size(ModuleInfo, [functor(_, ArgInsts) | LaterFunctors],
		Seen, Size0, Size) :-
	% Count the functor itself first, then its arguments.
	SizeWithFunctor is Size0 + 1,
	pd_util__inst_list_size(ModuleInfo, ArgInsts, Seen,
		SizeWithFunctor, SizeWithArgs),
	pd_util__bound_inst_size(ModuleInfo, LaterFunctors, Seen,
		SizeWithArgs, Size).

% pd_util__inst_list_size(ModuleInfo, Insts, Seen, Size0, Size):
%	Size is Size0 plus the sum of pd_util__inst_size_2 over Insts.

pd_util__inst_list_size(_, [], _, Total, Total).
pd_util__inst_list_size(ModuleInfo, [FirstInst | LaterInsts],
		Seen, Total0, Total) :-
	pd_util__inst_size_2(ModuleInfo, FirstInst, Seen, FirstSize),
	Total1 is Total0 + FirstSize,
	pd_util__inst_list_size(ModuleInfo, LaterInsts, Seen, Total1, Total).

%-----------------------------------------------------------------------------%

% pd_util__goals_match(ModuleInfo, OldGoal, OldArgs, OldArgTypes,
%		NewGoal, NewVarTypes, OldNewRenaming, TypeSubn):
%	Succeeds if the conjuncts of NewGoal match those of OldGoal
%	pairwise (see pd_util__goals_match_2), giving OldNewRenaming
%	from OldGoal's variables to NewGoal's. In addition:
%	  - every variable in OldArgs must be renamed by OldNewRenaming;
%	  - every non-local of NewGoal must be the image of one of
%	    OldArgs;
%	  - the types of the renamed OldArgs must subsume the types
%	    of their images in NewVarTypes. TypeSubn is the resulting
%	    type substitution.
pd_util__goals_match(_ModuleInfo, OldGoal, OldArgs, OldArgTypes,
		NewGoal, NewVarTypes, OldNewRenaming, TypeSubn) :-

	goal_to_conj_list(OldGoal, OldGoalList),
	goal_to_conj_list(NewGoal, NewGoalList),
	map__init(OldNewRenaming0),
	pd_util__goals_match_2(OldGoalList, NewGoalList,
		OldNewRenaming0, OldNewRenaming),

	%
	% Check that the goal produces a superset of the outputs of the
	% version we are searching for. 
	% (list__map with a semidet closure fails unless every OldArg
	% is a key of OldNewRenaming.)
	%
	Search = lambda([K1::in, V1::out] is semidet,
			map__search(OldNewRenaming, K1, V1)),
	list__map(Search, OldArgs, NewArgs),
	NewGoal = _ - NewGoalInfo,
	goal_info_get_nonlocals(NewGoalInfo, NewNonLocals),
	set__delete_list(NewNonLocals, NewArgs, UnmatchedNonLocals),
	set__empty(UnmatchedNonLocals),
	
	% Check that argument types of NewGoal are subsumed by 
	% those of OldGoal.
	% NOTE(review): since every OldArg is known to be renamed at this
	% point, MatchingArgTypes should equal OldArgTypes; confirm.
	pd_util__collect_matching_arg_types(OldArgs, OldArgTypes, 
		OldNewRenaming, [], MatchingArgTypes),
	map__apply_to_list(NewArgs, NewVarTypes, NewArgTypes),
	type_list_subsumes(MatchingArgTypes, NewArgTypes, TypeSubn).

:- pred pd_util__collect_matching_arg_types(list(var)::in, list(type)::in,
		map(var, var)::in, list(type)::in, list(type)::out) is det.

% pd_util__collect_matching_arg_types(Args, ArgTypes, Renaming,
%		RevKept0, Kept):
%	Walk Args and ArgTypes in parallel and keep the type of each
%	argument that is a key of Renaming. Kept types are pushed onto
%	the accumulator RevKept0, which is reversed at the end. When
%	RevKept0 is [], the result is in the same order as ArgTypes.
%	Aborts if Args and ArgTypes differ in length.

pd_util__collect_matching_arg_types([Arg | LaterArgs],
		[ArgType | LaterTypes], Renaming, RevKept0, Kept) :-
	( map__contains(Renaming, Arg) ->
		RevKept1 = [ArgType | RevKept0]
	;
		RevKept1 = RevKept0
	),
	pd_util__collect_matching_arg_types(LaterArgs, LaterTypes,
		Renaming, RevKept1, Kept).
pd_util__collect_matching_arg_types([], [], _, RevKept, Kept) :-
	list__reverse(RevKept, Kept).
pd_util__collect_matching_arg_types([_|_], [], _, _, _) :-
	error("pd_util__collect_matching_arg_types").
pd_util__collect_matching_arg_types([], [_|_], _, _, _) :-
	error("pd_util__collect_matching_arg_types").

% pd_util__goals_match_2(OldGoals, NewGoals, Renaming0, Renaming):
%	Succeeds if the two conjunct lists have the same length and each
%	pair of goals has the same shape:
%	  - unifications of the same kind (simple_test, assign,
%	    construct or deconstruct, the last two with the same
%	    cons_id);
%	  - calls to the same procedure;
%	  - higher-order calls with identical types, modes, determinism
%	    and pred_or_func;
%	  - negations, or `some' goals, whose sub-goals match
%	    recursively (the quantified variables are ignored).
%	The corresponding variables of each pair are added to Renaming0
%	(old variable -> new variable). The match fails if an old
%	variable would have to map to two different new variables.
%	Any other kind of goal never matches.
%	NOTE(review): only old->new consistency is checked, so two
%	distinct old variables may map to the same new variable.
:- pred pd_util__goals_match_2(list(hlds_goal)::in,
		list(hlds_goal)::in, map(var, var)::in,
		map(var, var)::out) is semidet.

pd_util__goals_match_2([], [], R, R).
pd_util__goals_match_2([OldGoal | OldGoals], [NewGoal | NewGoals],
		ONRenaming0, ONRenaming) :-	
	(
		(
			OldGoal = unify(_, _, _, OldUnification, _) - _,
			NewGoal = unify(_, _, _, NewUnification, _) - _,
			(
				OldUnification = simple_test(OldVar1, OldVar2),
				NewUnification = simple_test(NewVar1, NewVar2),
				OldArgs = [OldVar1, OldVar2],
				NewArgs = [NewVar1, NewVar2]
			;
				OldUnification = assign(OldVar1, OldVar2),
				NewUnification = assign(NewVar1, NewVar2),
				OldArgs = [OldVar1, OldVar2],
				NewArgs = [NewVar1, NewVar2]
			;
				OldUnification = construct(OldVar, ConsId, 
						OldArgs1, _),
				NewUnification = construct(NewVar, ConsId, 
						NewArgs1,_ ),
				OldArgs = [OldVar | OldArgs1],
				NewArgs = [NewVar | NewArgs1]
			;
				OldUnification = deconstruct(OldVar, ConsId,
							OldArgs1, _, _),
				NewUnification = deconstruct(NewVar, ConsId,
							NewArgs1, _, _),
				OldArgs = [OldVar | OldArgs1],
				NewArgs = [NewVar | NewArgs1]
			)	
		;
			OldGoal = call(PredId, ProcId, OldArgs, _, _, _) - _,
			NewGoal = call(PredId, ProcId, NewArgs, _, _, _) - _
		;
			OldGoal = higher_order_call(OldVar, OldArgs1, Types,
					Modes, Det, PredOrFunc) - _,
			NewGoal = higher_order_call(NewVar, NewArgs1, Types,
					Modes, Det, PredOrFunc) - _,
			OldArgs = [OldVar | OldArgs1],
			NewArgs = [NewVar | NewArgs1]
		)
	->
		% Extend the renaming with the paired variables. An old
		% variable that is already mapped must map to the same
		% new variable again, otherwise the lambda (and hence
		% the match) fails.
		assoc_list__from_corresponding_lists(OldArgs, 
			NewArgs, ONArgsList),
		MapInsert =
			lambda([KeyValue::in, Map0::in, Map::out] is semidet, (
				KeyValue = Key - Value,
				( map__search(Map0, Key, Value0) ->
					Value = Value0,
					Map = Map0
				;
					map__det_insert(Map0, Key, Value, Map)
				)
			)),
		list__foldl(MapInsert, ONArgsList, ONRenaming0, ONRenaming1)
	;
		(
			OldGoal = not(OldSubGoal) - _,
			NewGoal = not(NewSubGoal) - _
		;
			OldGoal = some(_, OldSubGoal) - _,
			NewGoal = some(_, NewSubGoal) - _
		)
	->
		goal_to_conj_list(OldSubGoal, OldSubGoalList),
		goal_to_conj_list(NewSubGoal, NewSubGoalList),
		pd_util__goals_match_2(OldSubGoalList, NewSubGoalList,
			ONRenaming0, ONRenaming1)
	;
		% Any other kind of goal is never considered a match.
		fail
	),
	pd_util__goals_match_2(OldGoals, NewGoals, 
		ONRenaming1, ONRenaming).

%-----------------------------------------------------------------------------%

% pd_util__can_reorder_goals(ModuleInfo, FullyStrict, EarlierGoal,
%		LaterGoal):
%	Succeeds if EarlierGoal and LaterGoal may be swapped. Both goals
%	must be pure, the swap must not worsen termination (see
%	pd_util__reordering_maintains_termination), and neither goal
%	may change the instantiatedness of a non-local of the other.

pd_util__can_reorder_goals(ModuleInfo, FullyStrict, EarlierGoal, LaterGoal) :-
	EarlierGoal = _ - EarlierInfo,
	LaterGoal = _ - LaterInfo,

		% Impure goals cannot be reordered.
	goal_info_is_pure(EarlierInfo),
	goal_info_is_pure(LaterInfo),

	pd_util__reordering_maintains_termination(ModuleInfo, FullyStrict,
		EarlierGoal, LaterGoal),

		% The later goal must not depend on the outputs of the
		% earlier goal.
	\+ goal_depends_on_goal(EarlierGoal, LaterGoal),

		% The later goal must not change the instantiatedness of
		% any of the non-locals of the earlier goal. This is
		% necessary if the later goal clobbers any of them, and
		% avoids rerunning full mode analysis in other cases.
	\+ goal_depends_on_goal(LaterGoal, EarlierGoal).

:- pred goal_depends_on_goal(hlds_goal::in, hlds_goal::in) is semidet.

% goal_depends_on_goal(GoalA, GoalB):
%	Succeeds if some variable whose instantiatedness is changed by
%	GoalA's instmap delta is a non-local of GoalB.

goal_depends_on_goal(_ - ProducerInfo, _ - ConsumerInfo) :-
	goal_info_get_nonlocals(ConsumerInfo, ConsumerNonLocals),
	goal_info_get_instmap_delta(ProducerInfo, ProducerDelta),
	instmap_delta_changed_vars(ProducerDelta, ProducerChanged),
	set__intersect(ProducerChanged, ConsumerNonLocals, SharedVars),
	\+ set__empty(SharedVars).
	
% pd_util__reordering_maintains_termination(ModuleInfo, FullyStrict,
%		EarlierGoal, LaterGoal):
%	Succeeds if swapping EarlierGoal and LaterGoal cannot make the
%	program less likely to terminate. The decision uses the can_fail
%	component of each goal's determinism and
%	code_aux__goal_cannot_loop.
pd_util__reordering_maintains_termination(ModuleInfo, FullyStrict, 
		EarlierGoal, LaterGoal) :-
	EarlierGoal = _ - EarlierGoalInfo,
	LaterGoal = _ - LaterGoalInfo,

	goal_info_get_determinism(EarlierGoalInfo, EarlierDetism),
	determinism_components(EarlierDetism, EarlierCanFail, _),
	goal_info_get_determinism(LaterGoalInfo, LaterDetism),
	determinism_components(LaterDetism, LaterCanFail, _),

		% If --fully-strict was specified, don't convert 
		% (can_loop, can_fail) into (can_fail, can_loop). 
		% (The condition holds when the earlier goal might loop;
		% the then-part fails unless the later goal cannot fail.)
	( 
		FullyStrict = yes, 
		\+ code_aux__goal_cannot_loop(ModuleInfo, EarlierGoal)
	->
		LaterCanFail = cannot_fail
	;
		true
	),
		% Don't convert (can_fail, can_loop) into 
		% (can_loop, can_fail), since this could worsen 
		% the termination properties of the program.
	( EarlierCanFail = can_fail ->
		code_aux__goal_cannot_loop(ModuleInfo, LaterGoal)
	;
		true
	).

%-----------------------------------------------------------------------------%
%-----------------------------------------------------------------------------%



More information about the developers mailing list