[m-rev.] Extend solver initialisation to cover disjunctions and if-then-elses
Ralph Becket
rafe at cs.mu.OZ.AU
Fri Mar 18 11:40:28 AEDT 2005
Estimated hours taken: 12
Branches: main, release
Fix a bug in the initialisation of solver variables.
Ensure that if a solver variable is initialised by some branches of a
disjunction, then it is initialised by all branches. Make a similar
guarantee for the branches of if-then-else goals.
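To illustrate the guarantee, here is a minimal, self-contained sketch that
models the check on toy data. The names toy_inst, toy_instmap,
must_be_initialised and vars_to_init are hypothetical stand-ins, not the
compiler's own types or predicates; the real code works on instmaps and
prog_vars and also checks that each variable has a solver type.

```mercury
:- module solver_init_sketch.
:- interface.
:- import_module io.
:- pred main(io::di, io::uo) is det.

:- implementation.
:- import_module list.
:- import_module map.

    % Hypothetical model: a variable is either still free at the end of
    % a branch, or has been initialised (inst any).
:- type toy_inst
    --->    inst_free
    ;       inst_any.

    % Hypothetical model of an instmap: variable name -> inst.
:- type toy_instmap == map(string, toy_inst).

    % The variables that are non-free at the end of at least one branch.
    % These must end up initialised in every branch.
:- func must_be_initialised(list(string), list(toy_instmap)) = list(string).
must_be_initialised(Vars, InstMaps) =
    list.filter(is_bound_in_some(InstMaps), Vars).

:- pred is_bound_in_some(list(toy_instmap)::in, string::in) is semidet.
is_bound_in_some(InstMaps, Var) :-
    list.member(InstMap, InstMaps),
    map.search(InstMap, Var, inst_any).

    % The variables a particular branch leaves free, and which therefore
    % need an initialisation call appended to that branch.
:- func vars_to_init(list(string), toy_instmap) = list(string).
vars_to_init(Ensure, InstMap) =
    list.filter(is_free_in(InstMap), Ensure).

:- pred is_free_in(toy_instmap::in, string::in) is semidet.
is_free_in(InstMap, Var) :-
    map.search(InstMap, Var, inst_free).

:- pred report(list(string)::in, toy_instmap::in, io::di, io::uo) is det.
report(Ensure, InstMap, !IO) :-
    io.write(vars_to_init(Ensure, InstMap), !IO),
    io.nl(!IO).

main(!IO) :-
    % Three branches shaped like the disjunction test case: the first
    % leaves Foo free, the other two bind it.
    InstMaps = [
        map.singleton("Foo", inst_free),
        map.singleton("Foo", inst_any),
        map.singleton("Foo", inst_any)
    ],
    Ensure = must_be_initialised(["Foo"], InstMaps),
    list.foldl(report(Ensure), InstMaps, !IO).
```

In this model only the first branch should report Foo as needing an
initialisation call, which corresponds to the `Bar = a' disjunct in the
solver_disj_inits test case.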
compiler/modecheck_unify.m:
Move a call to mode_info_get_var_types so that it retrieves the
VarTypes map at the right point. I had moved this call to the
start of the clause, but it turns out that new prog_vars may be
introduced in the body of the clause and it is important to
have the up-to-date VarTypes including these variables. My
earlier change broke the hard_coded/unify_existq_cons test case,
which is now fixed.
compiler/modes.m:
Change modecheck_goal_expr to handle solver variables that should
be initialised at the ends of disjunction branches or if-then-else
branches.
Add various support predicates and functions.
Change modecheck_disj_list so that it no longer flattens nested
disjunctions. This ensures that modecheck_disj_list returns exactly
one instmap per disjunct. A new function, flatten_disjs, then
flattens the nested disjunctions once the initialisation calls
have been added.
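The flattening step can be sketched on a toy goal type. The toy_goal type
below is a hypothetical stand-in for hlds_goal, and the two functions only
mirror the shape of flatten_disjs and flatten_disj in the patch; this is
an illustration under those assumptions, not the compiler code itself.

```mercury
:- module flatten_sketch.
:- interface.
:- import_module io.
:- pred main(io::di, io::uo) is det.

:- implementation.
:- import_module list.

    % Hypothetical stand-in for hlds_goal: either an atomic goal or a
    % (possibly nested) disjunction.
:- type toy_goal
    --->    atom(string)
    ;       disj(list(toy_goal)).

    % Flatten a list of disjuncts, splicing in the members of any
    % nested disjunction while preserving left-to-right order.
:- func flatten_disjs(list(toy_goal)) = list(toy_goal).
flatten_disjs(Disjs) = list.foldr(flatten_disj, Disjs, []).

:- func flatten_disj(toy_goal, list(toy_goal)) = list(toy_goal).
flatten_disj(Goal, Acc0) = Acc :-
    (
        % A nested disjunction: flatten its members in front of
        % whatever has been accumulated so far.
        Goal = disj(Inner),
        Acc = list.foldr(flatten_disj, Inner, Acc0)
    ;
        % Any other goal is kept as a single disjunct.
        Goal = atom(_),
        Acc = [Goal | Acc0]
    ).

main(!IO) :-
    Nested = [atom("a"), disj([atom("b"), disj([atom("c")])]), atom("d")],
    io.write(flatten_disjs(Nested), !IO),
    io.nl(!IO).
```

Folding from the right keeps the disjuncts in source order, so the nested
example above should come out as the four atoms a, b, c, d in that order.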
tests/hard_coded/Mmakefile:
tests/hard_coded/solver_disj_inits.m:
tests/hard_coded/solver_disj_inits.exp:
tests/hard_coded/solver_ite_inits.m:
tests/hard_coded/solver_ite_inits.exp:
Add test cases.
Index: compiler/modecheck_unify.m
===================================================================
RCS file: /home/mercury1/repository/mercury/compiler/modecheck_unify.m,v
retrieving revision 1.75
diff -u -r1.75 modecheck_unify.m
--- compiler/modecheck_unify.m 15 Mar 2005 02:51:19 -0000 1.75
+++ compiler/modecheck_unify.m 17 Mar 2005 06:23:02 -0000
@@ -458,7 +458,6 @@
Unification0, UnifyContext, GoalInfo0, Goal, !ModeInfo, !IO) :-
mode_info_get_module_info(!.ModeInfo, ModuleInfo0),
mode_info_get_how_to_check(!.ModeInfo, HowToCheckGoal),
- mode_info_get_var_types(!.ModeInfo, VarTypes),
%
% Fully module qualify all cons_ids
@@ -492,6 +491,8 @@
mode_info_var_is_live(!.ModeInfo, X, LiveX),
ExtraGoals0 = no_extra_goals
),
+
+ mode_info_get_var_types(!.ModeInfo, VarTypes),
(
% If we are allowed to insert solver type initialisation
Index: compiler/modes.m
===================================================================
RCS file: /home/mercury1/repository/mercury/compiler/modes.m,v
retrieving revision 1.294
diff -u -r1.294 modes.m
--- compiler/modes.m 15 Mar 2005 02:51:20 -0000 1.294
+++ compiler/modes.m 17 Mar 2005 05:10:18 -0000
@@ -1110,17 +1110,23 @@
instmap__unify(NonLocals, InstMapNonlocalList, !ModeInfo),
mode_checkpoint(exit, "par_conj", !ModeInfo, !IO).
-modecheck_goal_expr(disj(List0), GoalInfo0, Goal, !ModeInfo, !IO) :-
+modecheck_goal_expr(disj(Disjs0), GoalInfo0, Goal, !ModeInfo, !IO) :-
mode_checkpoint(enter, "disj", !ModeInfo, !IO),
- ( List0 = [] -> % for efficiency, optimize common case
- Goal = disj(List0),
+ ( Disjs0 = [] -> % for efficiency, optimize common case
+ Goal = disj(Disjs0),
instmap__init_unreachable(InstMap),
mode_info_set_instmap(InstMap, !ModeInfo)
;
goal_info_get_nonlocals(GoalInfo0, NonLocals),
- modecheck_disj_list(List0, List, InstMapList, !ModeInfo, !IO),
+ modecheck_disj_list(Disjs0, Disjs1, InstMapList0,
+ !ModeInfo, !IO),
+ mode_info_get_var_types(!.ModeInfo, VarTypes),
+ handle_solver_vars_in_disjs(set__to_sorted_list(NonLocals),
+ VarTypes, Disjs1, Disjs2, InstMapList0, InstMapList,
+ !ModeInfo),
+ Disjs = flatten_disjs(Disjs2),
instmap__merge(NonLocals, InstMapList, disj, !ModeInfo),
- disj_list_to_goal(List, GoalInfo0, Goal - _GoalInfo)
+ disj_list_to_goal(Disjs, GoalInfo0, Goal - _GoalInfo)
),
mode_checkpoint(exit, "disj", !ModeInfo, !IO).
@@ -1141,18 +1147,23 @@
mode_info_remove_live_vars(ThenVars, !ModeInfo),
mode_info_unlock_vars(if_then_else, NonLocals, !ModeInfo),
( instmap__is_reachable(InstMapCond) ->
- modecheck_goal(Then0, Then, !ModeInfo, !IO),
- mode_info_get_instmap(!.ModeInfo, InstMapThen)
+ modecheck_goal(Then0, Then1, !ModeInfo, !IO),
+ mode_info_get_instmap(!.ModeInfo, InstMapThen1)
;
% We should not mode-analyse the goal, since it is unreachable.
% Instead we optimize the goal away, so that later passes
% won't complain about it not having mode information.
- true_goal(Then),
- InstMapThen = InstMapCond
+ true_goal(Then1),
+ InstMapThen1 = InstMapCond
),
mode_info_set_instmap(InstMap0, !ModeInfo),
- modecheck_goal(Else0, Else, !ModeInfo, !IO),
- mode_info_get_instmap(!.ModeInfo, InstMapElse),
+ modecheck_goal(Else0, Else1, !ModeInfo, !IO),
+ mode_info_get_instmap(!.ModeInfo, InstMapElse1),
+ mode_info_get_var_types(!.ModeInfo, VarTypes),
+ handle_solver_vars_in_ite(set__to_sorted_list(NonLocals), VarTypes,
+ Then1, Then, Else1, Else,
+ InstMapThen1, InstMapThen, InstMapElse1, InstMapElse,
+ !ModeInfo),
mode_info_set_instmap(InstMap0, !ModeInfo),
instmap__merge(NonLocals, [InstMapThen, InstMapElse], if_then_else,
!ModeInfo),
@@ -1448,6 +1459,164 @@
%-----------------------------------------------------------------------------%
+ % Ensure that any non-local solver var that is initialised in
+ % one disjunct is initialised in all disjuncts.
+ %
+:- pred handle_solver_vars_in_disjs(list(prog_var)::in,
+ map(prog_var, (type))::in, list(hlds_goal)::in, list(hlds_goal)::out,
+ list(instmap)::in, list(instmap)::out, mode_info::in, mode_info::out)
+ is det.
+
+handle_solver_vars_in_disjs(NonLocals, VarTypes, Disjs0, Disjs,
+ InstMaps0, InstMaps, !ModeInfo) :-
+ mode_info_get_module_info(!.ModeInfo, ModuleInfo),
+ EnsureInitialised =
+ solver_vars_that_must_be_initialised(NonLocals, VarTypes,
+ ModuleInfo, InstMaps0),
+ add_necessary_disj_init_calls(Disjs0, Disjs, InstMaps0, InstMaps,
+ EnsureInitialised, !ModeInfo).
+
+
+:- pred handle_solver_vars_in_ite(list(prog_var)::in,
+ map(prog_var, (type))::in,
+ hlds_goal::in, hlds_goal::out, hlds_goal::in, hlds_goal::out,
+ instmap::in, instmap::out, instmap::in, instmap::out, mode_info::in,
+ mode_info::out) is det.
+
+handle_solver_vars_in_ite(NonLocals, VarTypes,
+ Then0, Then, Else0, Else,
+ ThenInstMap0, ThenInstMap, ElseInstMap0, ElseInstMap,
+ !ModeInfo) :-
+
+ mode_info_get_module_info(!.ModeInfo, ModuleInfo),
+ EnsureInitialised =
+ solver_vars_that_must_be_initialised(NonLocals, VarTypes,
+ ModuleInfo, [ThenInstMap0, ElseInstMap0]),
+
+ ThenVarsToInit =
+ solver_vars_to_init(EnsureInitialised, ModuleInfo,
+ ThenInstMap0),
+ construct_initialisation_calls(ThenVarsToInit, ThenInitCalls,
+ !ModeInfo),
+ Then = append_init_calls_to_goal(ThenInitCalls, Then0),
+ ThenInstMap = set_vars_to_inst_any(ThenVarsToInit, ThenInstMap0),
+
+ ElseVarsToInit =
+ solver_vars_to_init(EnsureInitialised, ModuleInfo,
+ ElseInstMap0),
+ construct_initialisation_calls(ElseVarsToInit, ElseInitCalls,
+ !ModeInfo),
+ Else = append_init_calls_to_goal(ElseInitCalls, Else0),
+ ElseInstMap = set_vars_to_inst_any(ElseVarsToInit, ElseInstMap0).
+
+
+:- func solver_vars_that_must_be_initialised(list(prog_var),
+ map(prog_var, (type)), module_info, list(instmap)) = list(prog_var).
+
+solver_vars_that_must_be_initialised([], _VarTypes, _ModuleInfo, _InstMaps) =
+ [].
+
+solver_vars_that_must_be_initialised([Var | Vars], VarTypes, ModuleInfo,
+ InstMaps) =
+ ( if
+ VarType = VarTypes ^ elem(Var),
+ type_util__type_is_solver_type(ModuleInfo, VarType),
+ list__member(InstMap, InstMaps),
+ instmap__lookup_var(InstMap, Var, Inst),
+ not inst_match__inst_is_free(ModuleInfo, Inst)
+ then
+ [ Var | solver_vars_that_must_be_initialised(Vars, VarTypes,
+ ModuleInfo, InstMaps) ]
+ else
+ solver_vars_that_must_be_initialised(Vars, VarTypes,
+ ModuleInfo, InstMaps)
+ ).
+
+
+:- pred add_necessary_disj_init_calls(list(hlds_goal)::in,
+ list(hlds_goal)::out, list(instmap)::in, list(instmap)::out,
+ list(prog_var)::in, mode_info::in, mode_info::out) is det.
+
+add_necessary_disj_init_calls([], [], [], [], _EnsureInitialised, !ModeInfo).
+
+add_necessary_disj_init_calls([], _, [_ | _], _, _, _, _) :-
+ error("modes.add_necessary_disj_init_calls: mismatched lists").
+
+add_necessary_disj_init_calls([_ | _], _, [], _, _, _, _) :-
+ error("modes.add_necessary_disj_init_calls: mismatched lists").
+
+add_necessary_disj_init_calls([Goal0 | Goals0], [Goal | Goals],
+ [InstMap0 | InstMaps0], [InstMap | InstMaps],
+ EnsureInitialised, !ModeInfo) :-
+ mode_info_get_module_info(!.ModeInfo, ModuleInfo),
+ VarsToInit =
+ solver_vars_to_init(EnsureInitialised, ModuleInfo, InstMap0),
+ construct_initialisation_calls(VarsToInit, InitCalls, !ModeInfo),
+ Goal = append_init_calls_to_goal(InitCalls, Goal0),
+ InstMap = set_vars_to_inst_any(VarsToInit, InstMap0),
+ add_necessary_disj_init_calls(Goals0, Goals, InstMaps0, InstMaps,
+ EnsureInitialised, !ModeInfo).
+
+
+:- func append_init_calls_to_goal(list(hlds_goal), hlds_goal) = hlds_goal.
+
+append_init_calls_to_goal(InitCalls, Goal0) = Goal :-
+ Goal0 = GoalExpr - GoalInfo,
+ (
+ GoalExpr = disj(Disjs0)
+ ->
+ Disjs = list__map(append_init_calls_to_goal(InitCalls),
+ Disjs0),
+ Goal = disj(Disjs) - GoalInfo
+ ;
+ goal_to_conj_list(Goal0, Conjs),
+ conj_list_to_goal(Conjs ++ InitCalls, GoalInfo, Goal)
+ ).
+
+
+:- func flatten_disjs(list(hlds_goal)) = list(hlds_goal).
+
+flatten_disjs(Disjs) = list__foldr(flatten_disj, Disjs, []).
+
+
+:- func flatten_disj(hlds_goal, list(hlds_goal)) = list(hlds_goal).
+
+flatten_disj(Disj, Disjs0) = Disjs :-
+ (
+ Disj = disj(Disjs1) - _GoalInfo
+ ->
+ Disjs = list__foldr(flatten_disj, Disjs1, Disjs0)
+ ;
+ Disjs = [Disj | Disjs0]
+ ).
+
+
+:- func solver_vars_to_init(list(prog_var), module_info, instmap) =
+ list(prog_var).
+
+solver_vars_to_init([], _ModuleInfo, _InstMap) = [].
+
+solver_vars_to_init([Var | Vars], ModuleInfo, InstMap) =
+ ( if
+ instmap__lookup_var(InstMap, Var, Inst),
+ inst_match__inst_is_free(ModuleInfo, Inst)
+ then
+ [Var | solver_vars_to_init(Vars, ModuleInfo, InstMap)]
+ else
+ solver_vars_to_init(Vars, ModuleInfo, InstMap)
+ ).
+
+
+:- func set_vars_to_inst_any(list(prog_var), instmap) = instmap.
+
+set_vars_to_inst_any([], InstMap) = InstMap.
+
+set_vars_to_inst_any([Var | Vars], InstMap0) = InstMap :-
+ instmap__set(InstMap0, Var, any_inst, InstMap1),
+ InstMap = set_vars_to_inst_any(Vars, InstMap1).
+
+%-----------------------------------------------------------------------------%
+
% Modecheck a conjunction without doing any reordering.
% This is used by handle_extra_goals above.
:- pred modecheck_conj_list_no_delay(list(hlds_goal)::in, list(hlds_goal)::out,
@@ -2190,20 +2359,14 @@
io::di, io::uo) is det.
modecheck_disj_list([], [], [], !ModeInfo, !IO).
-modecheck_disj_list([Goal0 | Goals0], Goals, [InstMap | InstMaps],
+modecheck_disj_list([Goal0 | Goals0], [Goal | Goals], [InstMap | InstMaps],
!ModeInfo, !IO) :-
mode_info_get_instmap(!.ModeInfo, InstMap0),
modecheck_goal(Goal0, Goal, !ModeInfo, !IO),
mode_info_get_instmap(!.ModeInfo, InstMap),
mode_info_set_instmap(InstMap0, !ModeInfo),
- modecheck_disj_list(Goals0, Goals1, InstMaps, !ModeInfo, !IO),
- %
- % If Goal is a nested disjunction,
- % then merge it with the outer disjunction.
- % If Goal is `fail', this will delete it.
- %
- goal_to_disj_list(Goal, DisjList),
- list__append(DisjList, Goals1, Goals).
+ modecheck_disj_list(Goals0, Goals, InstMaps, !ModeInfo, !IO).
+
:- pred modecheck_case_list(list(case)::in, prog_var::in, list(case)::out,
list(instmap)::out, mode_info::in, mode_info::out,
Index: tests/hard_coded/Mmakefile
===================================================================
RCS file: /home/mercury1/repository/tests/hard_coded/Mmakefile,v
retrieving revision 1.251
diff -u -r1.251 Mmakefile
--- tests/hard_coded/Mmakefile 15 Mar 2005 02:51:25 -0000 1.251
+++ tests/hard_coded/Mmakefile 18 Mar 2005 00:38:05 -0000
@@ -154,6 +154,8 @@
shift_test \
solve_quadratic \
solver_construction_init_test \
+ solver_disj_inits \
+ solver_ite_inits \
space \
stable_sort \
string_alignment \
Index: tests/hard_coded/solver_disj_inits.exp
===================================================================
RCS file: tests/hard_coded/solver_disj_inits.exp
diff -N tests/hard_coded/solver_disj_inits.exp
--- /dev/null 1 Jan 1970 00:00:00 -0000
+++ tests/hard_coded/solver_disj_inits.exp 18 Mar 2005 00:35:13 -0000
@@ -0,0 +1,3 @@
+0
+1
+2
Index: tests/hard_coded/solver_disj_inits.m
===================================================================
RCS file: tests/hard_coded/solver_disj_inits.m
diff -N tests/hard_coded/solver_disj_inits.m
--- /dev/null 1 Jan 1970 00:00:00 -0000
+++ tests/hard_coded/solver_disj_inits.m 18 Mar 2005 00:35:01 -0000
@@ -0,0 +1,74 @@
+%-----------------------------------------------------------------------------%
+% solver_disj_inits.m
+% Ralph Becket <rafe at cs.mu.oz.au>
+% Fri Mar 18 11:17:41 EST 2005
+% vim: ft=mercury ts=4 sw=4 et wm=0 tw=0
+%
+% Test that the compiler inserts solver variable initialisation calls
+% at the ends of disjuncts if necessary to ensure that solver variables
+% have compatible insts at the end of a disjunction.
+%
+%-----------------------------------------------------------------------------%
+
+:- module solver_disj_inits.
+
+:- interface.
+
+:- import_module io.
+
+
+
+:- pred main(io :: di, io :: uo) is det.
+
+%-----------------------------------------------------------------------------%
+%-----------------------------------------------------------------------------%
+
+:- implementation.
+
+:- import_module int.
+
+
+
+:- solver type foo
+ where representation is int,
+ initialisation is init,
+ ground is ground,
+ any is ground.
+
+
+
+:- pred init(foo::oa) is det.
+:- pragma promise_pure(init/1).
+init(X) :- impure X = 'representation to any foo/0'(0).
+
+:- func foo(int::in) = (foo::oa) is det.
+:- pragma promise_pure(foo/1).
+foo(N) = X :- impure X = 'representation to any foo/0'(N).
+
+:- pred write_foo(foo::ia, io::di, io::uo) is det.
+:- pragma promise_pure(write_foo/3).
+write_foo(Foo, !IO) :-
+ impure X = 'representation of any foo/0'(Foo),
+ io.print(X, !IO),
+ io.nl(!IO).
+
+
+
+:- type bar ---> a ; b ; c.
+
+:- func f(bar::in) = (foo::oa) is det.
+f(Bar) = Foo :-
+ ( Bar = a
+ ; Bar = b, Foo = foo(1)
+ ; Bar = c, Foo = foo(2)
+ ).
+
+%-----------------------------------------------------------------------------%
+
+main(!IO) :-
+ write_foo(f(a), !IO),
+ write_foo(f(b), !IO),
+ write_foo(f(c), !IO).
+
+%-----------------------------------------------------------------------------%
+%-----------------------------------------------------------------------------%
Index: tests/hard_coded/solver_ite_inits.exp
===================================================================
RCS file: tests/hard_coded/solver_ite_inits.exp
diff -N tests/hard_coded/solver_ite_inits.exp
--- /dev/null 1 Jan 1970 00:00:00 -0000
+++ tests/hard_coded/solver_ite_inits.exp 18 Mar 2005 00:37:46 -0000
@@ -0,0 +1,3 @@
+0
+1
+1
Index: tests/hard_coded/solver_ite_inits.m
===================================================================
RCS file: tests/hard_coded/solver_ite_inits.m
diff -N tests/hard_coded/solver_ite_inits.m
--- /dev/null 1 Jan 1970 00:00:00 -0000
+++ tests/hard_coded/solver_ite_inits.m 18 Mar 2005 00:37:27 -0000
@@ -0,0 +1,71 @@
+%-----------------------------------------------------------------------------%
+% solver_ite_inits.m
+% Ralph Becket <rafe at cs.mu.oz.au>
+% Fri Mar 18 11:17:41 EST 2005
+% vim: ft=mercury ts=4 sw=4 et wm=0 tw=0
+%
+% Test that the compiler inserts solver variable initialisation calls
+% at the ends of if-then-else branches if necessary to ensure that solver
+% variables have compatible insts at the end of the if-then-else.
+%
+%-----------------------------------------------------------------------------%
+
+:- module solver_ite_inits.
+
+:- interface.
+
+:- import_module io.
+
+
+
+:- pred main(io :: di, io :: uo) is det.
+
+%-----------------------------------------------------------------------------%
+%-----------------------------------------------------------------------------%
+
+:- implementation.
+
+:- import_module int.
+
+
+
+:- solver type foo
+ where representation is int,
+ initialisation is init,
+ ground is ground,
+ any is ground.
+
+
+
+:- pred init(foo::oa) is det.
+:- pragma promise_pure(init/1).
+init(X) :- impure X = 'representation to any foo/0'(0).
+
+:- func foo(int::in) = (foo::oa) is det.
+:- pragma promise_pure(foo/1).
+foo(N) = X :- impure X = 'representation to any foo/0'(N).
+
+:- pred write_foo(foo::ia, io::di, io::uo) is det.
+:- pragma promise_pure(write_foo/3).
+write_foo(Foo, !IO) :-
+ impure X = 'representation of any foo/0'(Foo),
+ io.print(X, !IO),
+ io.nl(!IO).
+
+
+
+:- type bar ---> a ; b ; c.
+
+:- func f(bar::in) = (foo::oa) is det.
+f(Bar) = Foo :-
+ ( if Bar = a then true else Foo = foo(1) ).
+
+%-----------------------------------------------------------------------------%
+
+main(!IO) :-
+ write_foo(f(a), !IO),
+ write_foo(f(b), !IO),
+ write_foo(f(c), !IO).
+
+%-----------------------------------------------------------------------------%
+%-----------------------------------------------------------------------------%