[m-rev.] Extend solver initialisation to cover disjunctions and if-then-elses

Ralph Becket rafe at cs.mu.OZ.AU
Fri Mar 18 11:40:28 AEDT 2005


Estimated hours taken: 12
Branches: main, release

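Fix a bug: a solver variable initialised in only some branches of a
disjunction or if-then-else could end up with incompatible insts at the
end of the goal.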

Ensure that if a non-local solver variable is initialised by some
branches of a disjunction, then it is initialised by all branches, by
inserting initialisation calls at the ends of the branches that leave
it free.  Make the same guarantee for the branches of if-then-else
goals.

compiler/modecheck_unify.m:
	Move a call to mode_info_get_var_types so that it retrieves the
	VarTypes map after any new prog_vars have been introduced.  I had
	moved this call to the start of the clause, but new prog_vars may
	be introduced in the body of the clause, and the VarTypes map used
	later in the clause must include them.  My earlier change broke the
	hard_coded/unify_existq_cons test case, which is now fixed.

compiler/modes.m:
	Change modecheck_goal_expr so that, when a non-local solver
	variable is initialised at the ends of some branches of a
	disjunction or if-then-else but not others, initialisation calls
	are added at the ends of the branches that leave it free.

	Add the support predicates and functions this needs, including
	handle_solver_vars_in_disjs and handle_solver_vars_in_ite.

	Change modecheck_disj_list so that it no longer flattens nested
	disjunctions.  This ensures that modecheck_disj_list returns
	exactly one instmap per disjunct, so the two lists can be
	processed in step.  The new function flatten_disjs then flattens
	nested disjunctions into a single disjunction, after any
	initialisation calls have been added.

tests/hard_coded/Mmakefile:
tests/hard_coded/solver_disj_inits.m:
tests/hard_coded/solver_disj_inits.exp:
tests/hard_coded/solver_ite_inits.m:
tests/hard_coded/solver_ite_inits.exp:
	Add test cases covering disjunctions and if-then-elses.

Index: compiler/modecheck_unify.m
===================================================================
RCS file: /home/mercury1/repository/mercury/compiler/modecheck_unify.m,v
retrieving revision 1.75
diff -u -r1.75 modecheck_unify.m
--- compiler/modecheck_unify.m	15 Mar 2005 02:51:19 -0000	1.75
+++ compiler/modecheck_unify.m	17 Mar 2005 06:23:02 -0000
@@ -458,7 +458,6 @@
 		Unification0, UnifyContext, GoalInfo0, Goal, !ModeInfo, !IO) :-
 	mode_info_get_module_info(!.ModeInfo, ModuleInfo0),
 	mode_info_get_how_to_check(!.ModeInfo, HowToCheckGoal),
-	mode_info_get_var_types(!.ModeInfo, VarTypes),
 
 	%
 	% Fully module qualify all cons_ids
@@ -492,6 +491,8 @@
 		mode_info_var_is_live(!.ModeInfo, X, LiveX),
 		ExtraGoals0 = no_extra_goals
 	),
+
+	mode_info_get_var_types(!.ModeInfo, VarTypes),
 
 	(
 		% If we are allowed to insert solver type initialisation
Index: compiler/modes.m
===================================================================
RCS file: /home/mercury1/repository/mercury/compiler/modes.m,v
retrieving revision 1.294
diff -u -r1.294 modes.m
--- compiler/modes.m	15 Mar 2005 02:51:20 -0000	1.294
+++ compiler/modes.m	17 Mar 2005 05:10:18 -0000
@@ -1110,17 +1110,23 @@
 	instmap__unify(NonLocals, InstMapNonlocalList, !ModeInfo),
 	mode_checkpoint(exit, "par_conj", !ModeInfo, !IO).
 
-modecheck_goal_expr(disj(List0), GoalInfo0, Goal, !ModeInfo, !IO) :-
+modecheck_goal_expr(disj(Disjs0), GoalInfo0, Goal, !ModeInfo, !IO) :-
 	mode_checkpoint(enter, "disj", !ModeInfo, !IO),
-	( List0 = [] ->	% for efficiency, optimize common case
-		Goal = disj(List0),
+	( Disjs0 = [] ->	% for efficiency, optimize common case
+		Goal = disj(Disjs0),
 		instmap__init_unreachable(InstMap),
 		mode_info_set_instmap(InstMap, !ModeInfo)
 	;
 		goal_info_get_nonlocals(GoalInfo0, NonLocals),
-		modecheck_disj_list(List0, List, InstMapList, !ModeInfo, !IO),
+		modecheck_disj_list(Disjs0, Disjs1, InstMapList0,
+			!ModeInfo, !IO),	% one end instmap per disjunct
+		mode_info_get_var_types(!.ModeInfo, VarTypes),
+		handle_solver_vars_in_disjs(set__to_sorted_list(NonLocals),
+			VarTypes, Disjs1, Disjs2, InstMapList0, InstMapList,
+			!ModeInfo),	% make disjuncts agree on solver var inits
+		Disjs = flatten_disjs(Disjs2),	% flatten only after init calls
 		instmap__merge(NonLocals, InstMapList, disj, !ModeInfo),
-		disj_list_to_goal(List, GoalInfo0, Goal - _GoalInfo)
+		disj_list_to_goal(Disjs, GoalInfo0, Goal - _GoalInfo)
 	),
 	mode_checkpoint(exit, "disj", !ModeInfo, !IO).
 
@@ -1141,18 +1147,23 @@
 	mode_info_remove_live_vars(ThenVars, !ModeInfo),
 	mode_info_unlock_vars(if_then_else, NonLocals, !ModeInfo),
 	( instmap__is_reachable(InstMapCond) ->
-		modecheck_goal(Then0, Then, !ModeInfo, !IO),
-		mode_info_get_instmap(!.ModeInfo, InstMapThen)
+		modecheck_goal(Then0, Then1, !ModeInfo, !IO),
+		mode_info_get_instmap(!.ModeInfo, InstMapThen1)
 	;
 		% We should not mode-analyse the goal, since it is unreachable.
 		% Instead we optimize the goal away, so that later passes
 		% won't complain about it not having mode information.
-		true_goal(Then),
-		InstMapThen = InstMapCond
+		true_goal(Then1),
+		InstMapThen1 = InstMapCond
 	),
 	mode_info_set_instmap(InstMap0, !ModeInfo),
-	modecheck_goal(Else0, Else, !ModeInfo, !IO),
-	mode_info_get_instmap(!.ModeInfo, InstMapElse),
+	modecheck_goal(Else0, Else1, !ModeInfo, !IO),
+	mode_info_get_instmap(!.ModeInfo, InstMapElse1),
+	mode_info_get_var_types(!.ModeInfo, VarTypes),
+	handle_solver_vars_in_ite(set__to_sorted_list(NonLocals), VarTypes,
+		Then1, Then, Else1, Else,
+		InstMapThen1, InstMapThen, InstMapElse1, InstMapElse,
+		!ModeInfo),
 	mode_info_set_instmap(InstMap0, !ModeInfo),
 	instmap__merge(NonLocals, [InstMapThen, InstMapElse], if_then_else,
 		!ModeInfo),
@@ -1448,6 +1459,164 @@
 
 %-----------------------------------------------------------------------------%
 
+	% Ensure that any non-local solver var that is not free at the end of
+	% some disjunct is initialised (by appended init calls) in all of them.
+	%
+:- pred handle_solver_vars_in_disjs(list(prog_var)::in,
+	map(prog_var, (type))::in, list(hlds_goal)::in, list(hlds_goal)::out,
+	list(instmap)::in, list(instmap)::out, mode_info::in, mode_info::out)
+	is det.
+
+handle_solver_vars_in_disjs(NonLocals, VarTypes, Disjs0, Disjs,
+		InstMaps0, InstMaps, !ModeInfo) :-
+	mode_info_get_module_info(!.ModeInfo, ModuleInfo),
+	EnsureInitialised =	% solver vars not free in some disjunct
+		solver_vars_that_must_be_initialised(NonLocals, VarTypes,
+		ModuleInfo, InstMaps0),
+	add_necessary_disj_init_calls(Disjs0, Disjs, InstMaps0, InstMaps,
+		EnsureInitialised, !ModeInfo).
+
+	% As handle_solver_vars_in_disjs, but for the then and else branches.
+:- pred handle_solver_vars_in_ite(list(prog_var)::in,
+	map(prog_var, (type))::in,
+	hlds_goal::in, hlds_goal::out, hlds_goal::in, hlds_goal::out,
+	instmap::in, instmap::out, instmap::in, instmap::out, mode_info::in,
+	mode_info::out) is det.
+
+handle_solver_vars_in_ite(NonLocals, VarTypes,
+		Then0, Then, Else0, Else,
+		ThenInstMap0, ThenInstMap, ElseInstMap0, ElseInstMap,
+		!ModeInfo) :-
+	% Solver vars in NonLocals that are not free at the end of a branch.
+	mode_info_get_module_info(!.ModeInfo, ModuleInfo),
+	EnsureInitialised =
+		solver_vars_that_must_be_initialised(NonLocals, VarTypes,
+		ModuleInfo, [ThenInstMap0, ElseInstMap0]),
+	% Initialise those the then branch leaves free, at its end.
+	ThenVarsToInit =
+		solver_vars_to_init(EnsureInitialised, ModuleInfo,
+		ThenInstMap0),
+	construct_initialisation_calls(ThenVarsToInit, ThenInitCalls,
+		!ModeInfo),
+	Then = append_init_calls_to_goal(ThenInitCalls, Then0),
+	ThenInstMap = set_vars_to_inst_any(ThenVarsToInit, ThenInstMap0),
+	% Likewise for the else branch.
+	ElseVarsToInit =
+		solver_vars_to_init(EnsureInitialised, ModuleInfo,
+		ElseInstMap0),
+	construct_initialisation_calls(ElseVarsToInit, ElseInitCalls,
+		!ModeInfo),
+	Else = append_init_calls_to_goal(ElseInitCalls, Else0),
+	ElseInstMap = set_vars_to_inst_any(ElseVarsToInit, ElseInstMap0).
+
+	% Return the solver-typed vars that are not free in some InstMap.
+:- func solver_vars_that_must_be_initialised(list(prog_var),
+	map(prog_var, (type)), module_info, list(instmap)) = list(prog_var).
+
+solver_vars_that_must_be_initialised([], _VarTypes, _ModuleInfo, _InstMaps) =
+	[].
+
+solver_vars_that_must_be_initialised([Var | Vars], VarTypes, ModuleInfo,
+		InstMaps) =
+	( if
+		VarType = VarTypes ^ elem(Var),
+		type_util__type_is_solver_type(ModuleInfo, VarType),
+		list__member(InstMap, InstMaps),	% any one instmap suffices
+		instmap__lookup_var(InstMap, Var, Inst),
+		not inst_match__inst_is_free(ModuleInfo, Inst)
+	  then
+	  	[ Var | solver_vars_that_must_be_initialised(Vars, VarTypes,
+			ModuleInfo, InstMaps) ]
+	  else
+	  	solver_vars_that_must_be_initialised(Vars, VarTypes,
+			ModuleInfo, InstMaps)
+	).
+
+	% Initialise, at the end of each disjunct, the vars it leaves free.
+:- pred add_necessary_disj_init_calls(list(hlds_goal)::in,
+	list(hlds_goal)::out, list(instmap)::in, list(instmap)::out,
+	list(prog_var)::in, mode_info::in, mode_info::out) is det.
+
+add_necessary_disj_init_calls([], [], [], [], _EnsureInitialised, !ModeInfo).
+
+add_necessary_disj_init_calls([], _, [_ | _], _, _, _, _) :-
+	error("modes.add_necessary_disj_init_calls: mismatched lists").
+
+add_necessary_disj_init_calls([_ | _], _, [], _, _, _, _) :-
+	error("modes.add_necessary_disj_init_calls: mismatched lists").
+
+add_necessary_disj_init_calls([Goal0 | Goals0], [Goal | Goals],
+		[InstMap0 | InstMaps0], [InstMap | InstMaps],
+		EnsureInitialised, !ModeInfo) :-
+	mode_info_get_module_info(!.ModeInfo, ModuleInfo),
+	VarsToInit =	% the EnsureInitialised vars still free here
+		solver_vars_to_init(EnsureInitialised, ModuleInfo, InstMap0),
+	construct_initialisation_calls(VarsToInit, InitCalls, !ModeInfo),
+	Goal = append_init_calls_to_goal(InitCalls, Goal0),
+	InstMap = set_vars_to_inst_any(VarsToInit, InstMap0),
+	add_necessary_disj_init_calls(Goals0, Goals, InstMaps0, InstMaps,
+		EnsureInitialised, !ModeInfo).
+
+	% Append InitCalls to Goal0 (recursing into each disjunct of a disj).
+:- func append_init_calls_to_goal(list(hlds_goal), hlds_goal) = hlds_goal.
+
+append_init_calls_to_goal(InitCalls, Goal0) = Goal :-
+	Goal0 = GoalExpr - GoalInfo,
+	(
+		GoalExpr = disj(Disjs0)
+	->
+		Disjs = list__map(append_init_calls_to_goal(InitCalls),
+				Disjs0),
+		Goal  = disj(Disjs) - GoalInfo
+	;
+		goal_to_conj_list(Goal0, Conjs),
+		conj_list_to_goal(Conjs ++ InitCalls, GoalInfo, Goal)
+	).
+
+	% Splice nested disjunctions, at any depth, into one list of disjuncts.
+:- func flatten_disjs(list(hlds_goal)) = list(hlds_goal).
+
+flatten_disjs(Disjs) = list__foldr(flatten_disj, Disjs, []).
+
+	% Prepend Disj to Disjs0, replacing a disj by its flattened disjuncts.
+:- func flatten_disj(hlds_goal, list(hlds_goal)) = list(hlds_goal).
+
+flatten_disj(Disj, Disjs0) = Disjs :-
+	(
+		Disj = disj(Disjs1) - _GoalInfo
+	->	% so an empty disj([]) (fail) is dropped
+		Disjs = list__foldr(flatten_disj, Disjs1, Disjs0)
+	;
+		Disjs = [Disj | Disjs0]
+	).
+
+	% Select those of the given vars whose inst in InstMap is free.
+:- func solver_vars_to_init(list(prog_var), module_info, instmap) =
+	list(prog_var).
+
+solver_vars_to_init([], _ModuleInfo, _InstMap) = [].
+
+solver_vars_to_init([Var | Vars], ModuleInfo, InstMap) =
+	( if
+		instmap__lookup_var(InstMap, Var, Inst),
+		inst_match__inst_is_free(ModuleInfo, Inst)
+	  then
+	  	[Var | solver_vars_to_init(Vars, ModuleInfo, InstMap)]
+	  else
+	  	solver_vars_to_init(Vars, ModuleInfo, InstMap)
+	).
+
+	% Set the inst of each of the given vars to any_inst in the instmap.
+:- func set_vars_to_inst_any(list(prog_var), instmap) = instmap.
+
+set_vars_to_inst_any([], InstMap) = InstMap.
+
+set_vars_to_inst_any([Var | Vars], InstMap0) = InstMap :-
+	instmap__set(InstMap0, Var, any_inst, InstMap1),
+	InstMap = set_vars_to_inst_any(Vars, InstMap1).
+
+%-----------------------------------------------------------------------------%
+
 	% Modecheck a conjunction without doing any reordering.
 	% This is used by handle_extra_goals above.
 :- pred modecheck_conj_list_no_delay(list(hlds_goal)::in, list(hlds_goal)::out,
@@ -2190,20 +2359,14 @@
 	io::di, io::uo) is det.
 
 modecheck_disj_list([], [], [], !ModeInfo, !IO).
-modecheck_disj_list([Goal0 | Goals0], Goals, [InstMap | InstMaps],
+modecheck_disj_list([Goal0 | Goals0], [Goal | Goals], [InstMap | InstMaps],
 		!ModeInfo, !IO) :-
 	mode_info_get_instmap(!.ModeInfo, InstMap0),
 	modecheck_goal(Goal0, Goal, !ModeInfo, !IO),
 	mode_info_get_instmap(!.ModeInfo, InstMap),
 	mode_info_set_instmap(InstMap0, !ModeInfo),
-	modecheck_disj_list(Goals0, Goals1, InstMaps, !ModeInfo, !IO),
-	%
-	% If Goal is a nested disjunction,
-	% then merge it with the outer disjunction.
-	% If Goal is `fail', this will delete it.
-	%
-	goal_to_disj_list(Goal, DisjList),
-	list__append(DisjList, Goals1, Goals).
+	modecheck_disj_list(Goals0, Goals, InstMaps, !ModeInfo, !IO).	% no flattening
+
 
 :- pred modecheck_case_list(list(case)::in, prog_var::in, list(case)::out,
 	list(instmap)::out, mode_info::in, mode_info::out,
Index: tests/hard_coded/Mmakefile
===================================================================
RCS file: /home/mercury1/repository/tests/hard_coded/Mmakefile,v
retrieving revision 1.251
diff -u -r1.251 Mmakefile
--- tests/hard_coded/Mmakefile	15 Mar 2005 02:51:25 -0000	1.251
+++ tests/hard_coded/Mmakefile	18 Mar 2005 00:38:05 -0000
@@ -154,6 +154,8 @@
 	shift_test \
 	solve_quadratic \
 	solver_construction_init_test \
+	solver_disj_inits \
+	solver_ite_inits \
 	space \
 	stable_sort \
 	string_alignment \
Index: tests/hard_coded/solver_disj_inits.exp
===================================================================
RCS file: tests/hard_coded/solver_disj_inits.exp
diff -N tests/hard_coded/solver_disj_inits.exp
--- /dev/null	1 Jan 1970 00:00:00 -0000
+++ tests/hard_coded/solver_disj_inits.exp	18 Mar 2005 00:35:13 -0000
@@ -0,0 +1,3 @@
+0
+1
+2
Index: tests/hard_coded/solver_disj_inits.m
===================================================================
RCS file: tests/hard_coded/solver_disj_inits.m
diff -N tests/hard_coded/solver_disj_inits.m
--- /dev/null	1 Jan 1970 00:00:00 -0000
+++ tests/hard_coded/solver_disj_inits.m	18 Mar 2005 00:35:01 -0000
@@ -0,0 +1,74 @@
+%-----------------------------------------------------------------------------%
+% solver_disj_inits.m
+% Ralph Becket <rafe at cs.mu.oz.au>
+% Fri Mar 18 11:17:41 EST 2005
+% vim: ft=mercury ts=4 sw=4 et wm=0 tw=0
+%
+% Test that the compiler inserts solver variable initialisation calls
+% at the ends of disjuncts if necessary to ensure that solver variables
+% have compatible insts at the end of a disjunction.
+%
+%-----------------------------------------------------------------------------%
+
+:- module solver_disj_inits.
+
+:- interface.
+
+:- import_module io.
+
+
+
+:- pred main(io :: di, io :: uo) is det.
+
+%-----------------------------------------------------------------------------%
+%-----------------------------------------------------------------------------%
+
+:- implementation.
+
+:- import_module int.
+
+
+
+:- solver type foo
+    where   representation is int,
+            initialisation is init,
+            ground         is ground,
+            any            is ground.
+
+
+% init/1 yields representation 0, which is what f(a) should print.
+:- pred init(foo::oa) is det.
+:- pragma promise_pure(init/1).
+init(X) :- impure X = 'representation to any foo/0'(0).
+
+:- func foo(int::in) = (foo::oa) is det.
+:- pragma promise_pure(foo/1).
+foo(N) = X :- impure X = 'representation to any foo/0'(N).
+
+:- pred write_foo(foo::ia, io::di, io::uo) is det.
+:- pragma promise_pure(write_foo/3).
+write_foo(Foo, !IO) :-
+    impure X = 'representation of any foo/0'(Foo),
+    io.print(X, !IO),
+    io.nl(!IO).
+
+
+
+:- type bar ---> a ; b ; c.
+% The Bar = a disjunct leaves Foo unbound; the compiler must add init(Foo).
+:- func f(bar::in) = (foo::oa) is det.
+f(Bar) = Foo :-
+    (   Bar = a
+    ;   Bar = b, Foo = foo(1)
+    ;   Bar = c, Foo = foo(2)
+    ).
+
+%-----------------------------------------------------------------------------%
+
+main(!IO) :-
+    write_foo(f(a), !IO),
+    write_foo(f(b), !IO),
+    write_foo(f(c), !IO).
+
+%-----------------------------------------------------------------------------%
+%-----------------------------------------------------------------------------%
Index: tests/hard_coded/solver_ite_inits.exp
===================================================================
RCS file: tests/hard_coded/solver_ite_inits.exp
diff -N tests/hard_coded/solver_ite_inits.exp
--- /dev/null	1 Jan 1970 00:00:00 -0000
+++ tests/hard_coded/solver_ite_inits.exp	18 Mar 2005 00:37:46 -0000
@@ -0,0 +1,3 @@
+0
+1
+1
Index: tests/hard_coded/solver_ite_inits.m
===================================================================
RCS file: tests/hard_coded/solver_ite_inits.m
diff -N tests/hard_coded/solver_ite_inits.m
--- /dev/null	1 Jan 1970 00:00:00 -0000
+++ tests/hard_coded/solver_ite_inits.m	18 Mar 2005 00:37:27 -0000
@@ -0,0 +1,71 @@
+%-----------------------------------------------------------------------------%
+% solver_ite_inits.m
+% Ralph Becket <rafe at cs.mu.oz.au>
+% Fri Mar 18 11:17:41 EST 2005
+% vim: ft=mercury ts=4 sw=4 et wm=0 tw=0
+%
+% Test that the compiler inserts solver variable initialisation calls
+% at the ends of if-then-else branches if necessary to ensure that solver
+% variables have compatible insts at the end of the if-then-else.
+%
+%-----------------------------------------------------------------------------%
+
+:- module solver_ite_inits.
+
+:- interface.
+
+:- import_module io.
+
+
+
+:- pred main(io :: di, io :: uo) is det.
+
+%-----------------------------------------------------------------------------%
+%-----------------------------------------------------------------------------%
+
+:- implementation.
+
+:- import_module int.
+
+
+
+:- solver type foo
+    where   representation is int,
+            initialisation is init,
+            ground         is ground,
+            any            is ground.
+
+
+% init/1 yields representation 0, which is what f(a) should print.
+:- pred init(foo::oa) is det.
+:- pragma promise_pure(init/1).
+init(X) :- impure X = 'representation to any foo/0'(0).
+
+:- func foo(int::in) = (foo::oa) is det.
+:- pragma promise_pure(foo/1).
+foo(N) = X :- impure X = 'representation to any foo/0'(N).
+
+:- pred write_foo(foo::ia, io::di, io::uo) is det.
+:- pragma promise_pure(write_foo/3).
+write_foo(Foo, !IO) :-
+    impure X = 'representation of any foo/0'(Foo),
+    io.print(X, !IO),
+    io.nl(!IO).
+
+
+
+:- type bar ---> a ; b ; c.
+% The then branch leaves Foo unbound; the compiler must add init(Foo) there.
+:- func f(bar::in) = (foo::oa) is det.
+f(Bar) = Foo :-
+    ( if Bar = a then true else Foo = foo(1) ).
+
+%-----------------------------------------------------------------------------%
+
+main(!IO) :-
+    write_foo(f(a), !IO),
+    write_foo(f(b), !IO),
+    write_foo(f(c), !IO).
+
+%-----------------------------------------------------------------------------%
+%-----------------------------------------------------------------------------%
--------------------------------------------------------------------------
mercury-reviews mailing list
post:  mercury-reviews at cs.mu.oz.au
administrative address: owner-mercury-reviews at cs.mu.oz.au
unsubscribe: Address: mercury-reviews-request at cs.mu.oz.au Message: unsubscribe
subscribe:   Address: mercury-reviews-request at cs.mu.oz.au Message: subscribe
--------------------------------------------------------------------------



More information about the reviews mailing list