[m-rev.] for review: Fix some bugs with constrained polymorphic modes.

Peter Wang novalazy at gmail.com
Thu Sep 4 16:58:02 AEST 2014


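1. An inst variable from the procedure head could be bound in the clause body.
The mode error in the call `P(X)' below was not detected because the inst
variable I was allowed to be bound to ground, even though I is a head inst
variable whose constraint must not change in the body.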

	:- pred p(pred(T), T).
	:- mode p(pred(in(I =< ground)) is det, in) is det.

	p(P, X) :- P(X).

2. In David Overton's thesis, a get_subst function produces an inst
substitution from the callee's initial insts to the arguments' initial
insts, and the substitution is applied to all insts from the callee.
In the implementation we build the substitution incrementally while
matching the arguments' insts against the callee's insts, but we lost
precision because the substitution built from earlier arguments was not
applied to the insts of later arguments in the list.
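For example, assume q's two arguments share the inst variable I. The
declaration of q below is illustrative only; the real test is in
tests/invalid/constrained_poly_insts2.m. Matching the first argument of
the call binds I to bound(apple). Applying that substitution before
matching the second argument makes the expected inst bound(apple), so
passing orange is reported as a mode error:

	:- type fruit ---> apple ; orange.
	:- inst apple ---> apple.

	:- pred apple1(fruit).
	:- mode apple1(in(apple)) is det.

	    % Assumed declaration: I is shared by both arguments.
	:- pred q(pred(T), T).
	:- mode q(pred(in(I =< ground)) is det, in(I =< ground)) is det.

	test6 :- q(apple1, orange).	% mode error expected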

3. We rename the callee's inst variables apart from our own, but we did not
keep the merged inst varset, so a renamed inst variable appearing in a
mode error was printed without its original name.
XXX Could keeping the merged varset also avert more serious problems?

4. inst_merge_3 was missing the case where only the second inst is a
constrained inst variable.
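A toy model may make the case ordering in inst_merge_2/inst_merge_3 easier
to follow. This is a hypothetical sketch, not compiler code. toy_inst is a
drastically simplified stand-in for mer_inst. The both-constrained branch
of toy_merge_3 is an assumption, because that part of inst_merge_3 is not
shown in this diff. Testing not_reached before any constrained inst vars
are stripped is what stops a merge with not_reached from discarding the
constraint:

	% Hypothetical toy model of inst merging; not the compiler's code.
	:- module toy_inst_merge.
	:- interface.
	:- import_module io.
	:- pred main(io::di, io::uo) is det.

	:- implementation.
	:- import_module list.

	    % Drastically simplified stand-in for mer_inst.
	:- type toy_inst
	    --->    free
	    ;       ground
	    ;       not_reached
	    ;       constrained(list(string), toy_inst).

	    % Like inst_merge_2: not_reached is the identity for merging,
	    % and it is tested before any constraints are stripped.
	:- func toy_merge(toy_inst, toy_inst) = toy_inst is semidet.

	toy_merge(A, B) = C :-
	    ( B = not_reached ->
	        C = A
	    ; A = not_reached ->
	        C = B
	    ;
	        C = toy_merge_3(A, B)
	    ).

	:- func toy_merge_3(toy_inst, toy_inst) = toy_inst is semidet.

	toy_merge_3(A, B) = C :-
	    ( A = constrained(VarsA, SubA) ->
	        % Assumed: identical inst var sets keep the constraint.
	        ( B = constrained(VarsA, SubB) ->
	            C = constrained(VarsA, toy_merge(SubA, SubB))
	        ;
	            C = toy_merge(SubA, B)
	        )
	    ; B = constrained(_, SubB) ->
	        % The new case: A is unconstrained, which is equivalent to
	        % A constrained by an empty set of inst vars.
	        C = toy_merge(A, SubB)
	    ;
	        C = toy_merge_4(A, B)
	    ).

	    % Base cases only; merging free with ground fails.
	:- func toy_merge_4(toy_inst, toy_inst) = toy_inst is semidet.

	toy_merge_4(A, B) = C :-
	    ( A = free, B = free, C = free
	    ; A = ground, B = ground, C = ground
	    ).

	main(!IO) :-
	    ConstrainedI = constrained(["I"], ground),
	    ( R1 = toy_merge(not_reached, ConstrainedI) ->
	        io.write(R1, !IO)
	    ;
	        io.write_string("no merge", !IO)
	    ),
	    io.nl(!IO),
	    ( R2 = toy_merge(ground, ConstrainedI) ->
	        io.write(R2, !IO)
	    ;
	        io.write_string("no merge", !IO)
	    ),
	    io.nl(!IO).

Tracing the rules by hand, the first merge returns the constrained inst
unchanged. The second uses the new case and yields plain ground.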

compiler/inst_util.m:
	Handle the case in inst_merge_3 where
	    InstA \= constrained_inst_vars(_, _),
	    InstB = constrained_inst_vars(_, _).

	Move the InstA = not_reached case to inst_merge_2 so that
	inst_merge(not_reached, InstB @ constrained_inst_vars(_, _)) = InstB
	still holds.

compiler/mode_util.m:
compiler/modecheck_call.m:
	Keep the merged inst varsets after renaming.

compiler/mode_info.m:
	Add the head inst vars to the mode_info structure, and add
	mode_info_set_instvarset.

compiler/modecheck_util.m:
	Add get_constrained_inst_vars to extract the constrained inst vars
	from a list of modes.

	Apply the substitution built from earlier arguments to each inst
	before matching it.

	Make modecheck_var_has_inst_list_* report a mode error if the computed
	substitution would change the constraint of any head inst variable.

compiler/modes.m:
	Initialise mode_info with head inst variables.

compiler/pd_info.m:
compiler/pd_util.m:
	Conform to the change to mode_info_init by recording the head inst
	vars in unfold_info (not specifically tested).

compiler/prog_data.m:
	Add the head_inst_vars type.

compiler/prog_mode.m:
	Export inst_apply_substitution.

	Make rename_apart_inst_vars return the merged inst varset.

	Fix comments.

tests/invalid/Mmakefile:
tests/invalid/constrained_poly_insts2.err_exp:
tests/invalid/constrained_poly_insts2.m:
tests/valid/Mmakefile:
tests/valid/constrained_poly_multi.m:
	Add test cases.

diff --git a/compiler/inst_util.m b/compiler/inst_util.m
index a4db519..e5f3a13 100644
--- a/compiler/inst_util.m
+++ b/compiler/inst_util.m
@@ -1712,6 +1712,8 @@ inst_merge_2(InstA, InstB, MaybeType, Inst, !ModuleInfo) :-
     inst_expand(!.ModuleInfo, InstB, ExpandedInstB),
     ( ExpandedInstB = not_reached ->
         Inst = ExpandedInstA
+    ; ExpandedInstA = not_reached ->
+        Inst = ExpandedInstB
     ;
         inst_merge_3(ExpandedInstA, ExpandedInstB, MaybeType, Inst,
             !ModuleInfo)
@@ -1737,8 +1739,11 @@ inst_merge_3(InstA, InstB, MaybeType, Inst, !ModuleInfo) :-
         ;
             inst_merge(SubInstA, InstB, MaybeType, Inst, !ModuleInfo)
         )
+    ; InstB = constrained_inst_vars(_InstVarsB, SubInstB) ->
+        % InstA \= constrained_inst_vars(_, _) is equivalent to
+        % constrained_inst_vars(InstVarsA, InstA) where InstVarsA = empty.
+        inst_merge(InstA, SubInstB, MaybeType, Inst, !ModuleInfo)
     ;
-        % XXX Why don't we look for InstB = constrained_inst_vars/2 here?
         inst_merge_4(InstA, InstB, MaybeType, Inst, !ModuleInfo)
     ).
 
@@ -1863,9 +1868,6 @@ inst_merge_4(InstA, InstB, MaybeType, Inst, !ModuleInfo) :-
         MaybeTypes = list.duplicate(list.length(ArgsA), no),
         inst_list_merge(ArgsA, ArgsB, MaybeTypes, Args, !ModuleInfo),
         Inst = abstract_inst(Name, Args)
-    ;
-        InstA = not_reached,
-        Inst = InstB
     ).
 
     % merge_uniq(A, B, C) succeeds if C is minimum of A and B in the ordering
diff --git a/compiler/mode_info.m b/compiler/mode_info.m
index 362ce00..649365e 100644
--- a/compiler/mode_info.m
+++ b/compiler/mode_info.m
@@ -106,8 +106,8 @@
 
     % Initialize the mode_info.
     %
-:- pred mode_info_init(module_info::in, pred_id::in,
-    proc_id::in, prog_context::in, set_of_progvar::in, instmap::in,
+:- pred mode_info_init(module_info::in, pred_id::in, proc_id::in,
+    prog_context::in, set_of_progvar::in, head_inst_vars::in, instmap::in,
     how_to_check_goal::in, may_change_called_proc::in, mode_info::out) is det.
 
     % The mode_info contains a flag indicating whether initialisation calls,
@@ -190,6 +190,8 @@
 :- pred mode_info_get_may_change_called_proc(mode_info::in,
     may_change_called_proc::out) is det.
 :- pred mode_info_get_initial_instmap(mode_info::in, instmap::out) is det.
+:- pred mode_info_get_head_inst_vars(mode_info::in, head_inst_vars::out)
+    is det.
 :- pred mode_info_get_checking_extra_goals(mode_info::in, bool::out) is det.
 :- pred mode_info_get_may_init_solver_vars(mode_info::in,
     may_init_solver_vars::out) is det.
@@ -224,6 +226,8 @@
     mode_info::in, mode_info::out) is det.
 :- pred mode_info_set_locked_vars(locked_vars::in,
     mode_info::in, mode_info::out) is det.
+:- pred mode_info_set_instvarset(inst_varset::in,
+    mode_info::in, mode_info::out) is det.
 :- pred mode_info_set_errors(list(mode_error_info)::in,
     mode_info::in, mode_info::out) is det.
 :- pred mode_info_set_warnings(list(mode_warning_info)::in,
@@ -337,11 +341,11 @@
 :- import_module check_hlds.delay_info.
 :- import_module check_hlds.mode_errors.
 :- import_module libs.
+:- import_module map.
 :- import_module libs.globals.
 :- import_module libs.options.
 
 :- import_module int.
-:- import_module map.
 :- import_module pair.
 :- import_module require.
 :- import_module term.
@@ -421,6 +425,10 @@
                 % by an argument mode that enforces a subtype.
                 msi_initial_instmap         :: instmap,
 
+                % The inst vars that appear in the procedure head and their
+                % constraints.
+                msi_head_inst_vars          :: head_inst_vars,
+
                 % The mode warnings found.
                 msi_warnings                :: list(mode_warning_info),
 
@@ -491,8 +499,8 @@
 
 %-----------------------------------------------------------------------------%
 
-mode_info_init(ModuleInfo, PredId, ProcId, Context, LiveVars, InstMap0,
-        HowToCheck, MayChangeProc, ModeInfo) :-
+mode_info_init(ModuleInfo, PredId, ProcId, Context, LiveVars, HeadInstVars,
+        InstMap0, HowToCheck, MayChangeProc, ModeInfo) :-
     module_info_get_globals(ModuleInfo, Globals),
     globals.lookup_bool_option(Globals, debug_modes, DebugModes),
     globals.lookup_int_option(Globals, debug_modes_pred_id,
@@ -536,9 +544,9 @@ mode_info_init(ModuleInfo, PredId, ProcId, Context, LiveVars, InstMap0,
     ModeSubInfo = mode_sub_info(PredId, ProcId, VarSet, VarTypes, Debug,
         LockedVars, LiveVarsBag, InstVarSet, ParallelVars, HowToCheck,
         MayChangeProc, MayInitSolverVars, LastCheckpointInstMap, Changed,
-        CheckingExtraGoals, InstMap0, WarningList, NeedToRequantify,
-        InPromisePurityScope, InFromGroundTerm, HadFromGroundTerm,
-        MakeGroundTermsUnique, InDuplForSwitch),
+        CheckingExtraGoals, InstMap0, HeadInstVars, WarningList,
+        NeedToRequantify, InPromisePurityScope, InFromGroundTerm,
+        HadFromGroundTerm, MakeGroundTermsUnique, InDuplForSwitch),
 
     mode_context_init(ModeContext),
     delay_info_init(DelayInfo),
@@ -581,6 +589,7 @@ mode_info_get_may_change_called_proc(MI,
 mode_info_get_may_init_solver_vars(MI,
     MI ^ mi_sub_info ^ msi_may_init_solver_vars).
 mode_info_get_initial_instmap(MI, MI ^ mi_sub_info ^ msi_initial_instmap).
+mode_info_get_head_inst_vars(MI, MI ^ mi_sub_info ^ msi_head_inst_vars).
 mode_info_get_checking_extra_goals(MI,
     MI ^ mi_sub_info ^ msi_checking_extra_goals).
 mode_info_get_in_from_ground_term(MI,
@@ -603,6 +612,8 @@ mode_info_set_mode_context(ModeContext,
     MI, MI ^ mi_mode_context := ModeContext).
 mode_info_set_locked_vars(LockedVars, MI,
     MI ^ mi_sub_info ^ msi_locked_vars := LockedVars).
+mode_info_set_instvarset(InstVarSet, MI,
+    MI ^ mi_sub_info ^ msi_instvarset := InstVarSet).
 mode_info_set_errors(Errors, MI, MI ^ mi_errors := Errors).
 mode_info_set_warnings(Warnings, MI,
     MI ^ mi_sub_info ^ msi_warnings := Warnings).
diff --git a/compiler/mode_util.m b/compiler/mode_util.m
index 02a9518..6dd6167 100644
--- a/compiler/mode_util.m
+++ b/compiler/mode_util.m
@@ -1471,9 +1471,10 @@ recompute_instmap_delta_call(PredId, ProcId, Args, VarTypes, InstMap,
     ;
         proc_info_get_argmodes(ProcInfo, ArgModes0),
         proc_info_get_inst_varset(ProcInfo, ProcInstVarSet),
-        InstVarSet = !.RI ^ ri_inst_varset,
-        rename_apart_inst_vars(InstVarSet, ProcInstVarSet,
+        InstVarSet0 = !.RI ^ ri_inst_varset,
+        rename_apart_inst_vars(InstVarSet0, ProcInstVarSet, InstVarSet,
             ArgModes0, ArgModes1),
+        !RI ^ ri_inst_varset := InstVarSet,
         mode_list_get_initial_insts(ModuleInfo0, ArgModes1, InitialInsts),
 
         % Compute the inst_var substitution from the initial insts
diff --git a/compiler/modecheck_call.m b/compiler/modecheck_call.m
index 24f676a..32d96a2 100644
--- a/compiler/modecheck_call.m
+++ b/compiler/modecheck_call.m
@@ -157,9 +157,10 @@ modecheck_call_pred(PredId, DeterminismKnown, ProcId0, TheProcId,
         %
         proc_info_get_argmodes(ProcInfo, ProcArgModes0),
         proc_info_get_inst_varset(ProcInfo, ProcInstVarSet),
-        mode_info_get_instvarset(!.ModeInfo, InstVarSet),
-        rename_apart_inst_vars(InstVarSet, ProcInstVarSet,
+        mode_info_get_instvarset(!.ModeInfo, InstVarSet0),
+        rename_apart_inst_vars(InstVarSet0, ProcInstVarSet, InstVarSet,
             ProcArgModes0, ProcArgModes),
+        mode_info_set_instvarset(InstVarSet, !ModeInfo),
         mode_list_get_initial_insts(ModuleInfo, ProcArgModes, InitialInsts),
         modecheck_var_has_inst_list_no_exact_match(ArgVars0, InitialInsts,
             ArgOffset, InstVarSub, !ModeInfo),
@@ -375,9 +376,10 @@ modecheck_find_matching_modes([ProcId | ProcIds], PredId, Procs, ArgVars0,
     map.lookup(Procs, ProcId, ProcInfo),
     proc_info_get_argmodes(ProcInfo, ProcArgModes0),
     proc_info_get_inst_varset(ProcInfo, ProcInstVarSet),
-    mode_info_get_instvarset(!.ModeInfo, InstVarSet),
-    rename_apart_inst_vars(InstVarSet, ProcInstVarSet, ProcArgModes0,
-        ProcArgModes),
+    mode_info_get_instvarset(!.ModeInfo, InstVarSet0),
+    rename_apart_inst_vars(InstVarSet0, ProcInstVarSet, InstVarSet,
+        ProcArgModes0, ProcArgModes),
+    mode_info_set_instvarset(InstVarSet, !ModeInfo),
     mode_info_get_module_info(!.ModeInfo, ModuleInfo),
     proc_info_arglives(ProcInfo, ModuleInfo, ProcArgLives0),
 
diff --git a/compiler/modecheck_util.m b/compiler/modecheck_util.m
index 963016b..863be92 100644
--- a/compiler/modecheck_util.m
+++ b/compiler/modecheck_util.m
@@ -16,6 +16,7 @@
 :- import_module check_hlds.mode_info.
 :- import_module hlds.
 :- import_module hlds.hlds_goal.
+:- import_module hlds.hlds_module.
 :- import_module hlds.hlds_pred.
 :- import_module hlds.instmap.
 :- import_module parse_tree.
@@ -176,6 +177,12 @@
 :- pred get_live_vars(list(prog_var)::in, list(is_live)::in,
     list(prog_var)::out) is det.
 
+    % Return a map of all the inst variables in the given modes, and the
+    % sub-insts to which they are constrained.
+    %
+:- pred get_constrained_inst_vars(module_info::in, list(mer_mode)::in,
+    head_inst_vars::out) is det.
+
 %-----------------------------------------------------------------------------%
 %-----------------------------------------------------------------------------%
 
@@ -185,16 +192,17 @@
 :- import_module check_hlds.inst_match.
 :- import_module check_hlds.inst_util.
 :- import_module check_hlds.mode_errors.
+:- import_module check_hlds.mode_util.
 :- import_module check_hlds.modecheck_goal.
 :- import_module check_hlds.modecheck_unify.
 :- import_module check_hlds.polymorphism.
 :- import_module check_hlds.type_util.
-:- import_module hlds.hlds_module.
 :- import_module hlds.pred_table.
 :- import_module hlds.special_pred.
 :- import_module mdbcomp.
 :- import_module mdbcomp.prim_data.
 :- import_module mdbcomp.sym_name.
+:- import_module parse_tree.prog_mode.
 :- import_module parse_tree.prog_type.
 :- import_module parse_tree.set_of_var.
 
@@ -205,9 +213,13 @@
 :- import_module pair.
 :- import_module require.
 :- import_module set.
+:- import_module set_tree234.
 :- import_module term.
+:- import_module unit.
 :- import_module varset.
 
+:- type expansions == set_tree234(inst_name).
+
 %-----------------------------------------------------------------------------%
 
 append_extra_goals(no_extra_goals, ExtraGoals, ExtraGoals).
@@ -463,12 +475,14 @@ modecheck_var_is_live_exact_match(VarId, ExpectedIsLive, !ModeInfo) :-
 modecheck_var_has_inst_list_exact_match(Vars, Insts, ArgNum, Subst,
         !ModeInfo) :-
     modecheck_var_has_inst_list_exact_match_2(Vars, Insts, ArgNum,
-        map.init, Subst, !ModeInfo).
+        map.init, Subst, !ModeInfo),
+    modecheck_head_inst_vars(Vars, Subst, !ModeInfo).
 
 modecheck_var_has_inst_list_no_exact_match(Vars, Insts, ArgNum, Subst,
         !ModeInfo) :-
     modecheck_var_has_inst_list_no_exact_match_2(Vars, Insts, ArgNum,
-        map.init, Subst, !ModeInfo).
+        map.init, Subst, !ModeInfo),
+    modecheck_head_inst_vars(Vars, Subst, !ModeInfo).
 
 :- pred modecheck_var_has_inst_list_exact_match_2(list(prog_var)::in,
     list(mer_inst)::in, int::in, inst_var_sub::in, inst_var_sub::out,
@@ -511,7 +525,9 @@ modecheck_var_has_inst_list_no_exact_match_2([Var | Vars], [Inst | Insts],
     inst_var_sub::in, inst_var_sub::out,
     mode_info::in, mode_info::out) is det.
 
-modecheck_var_has_inst_exact_match(Var, Inst, !Subst, !ModeInfo) :-
+modecheck_var_has_inst_exact_match(Var, Inst0, !Subst, !ModeInfo) :-
+    % Apply the substitution computed while matching earlier arguments.
+    inst_apply_substitution(!.Subst, Inst0, Inst),
     mode_info_get_instmap(!.ModeInfo, InstMap),
     instmap_lookup_var(InstMap, Var, VarInst),
     mode_info_get_var_types(!.ModeInfo, VarTypes),
@@ -532,7 +548,9 @@ modecheck_var_has_inst_exact_match(Var, Inst, !Subst, !ModeInfo) :-
     inst_var_sub::in, inst_var_sub::out,
     mode_info::in, mode_info::out) is det.
 
-modecheck_var_has_inst_no_exact_match(Var, Inst, !Subst, !ModeInfo) :-
+modecheck_var_has_inst_no_exact_match(Var, Inst0, !Subst, !ModeInfo) :-
+    % Apply the substitution computed while matching earlier arguments.
+    inst_apply_substitution(!.Subst, Inst0, Inst),
     mode_info_get_instmap(!.ModeInfo, InstMap),
     instmap_lookup_var(InstMap, Var, VarInst),
     mode_info_get_var_types(!.ModeInfo, VarTypes),
@@ -567,6 +585,35 @@ modecheck_introduced_type_info_var_has_inst_no_exact_match(Var, Type, Inst,
 
 %-----------------------------------------------------------------------------%
 
+:- pred modecheck_head_inst_vars(list(prog_var)::in, inst_var_sub::in,
+    mode_info::in, mode_info::out) is det.
+
+modecheck_head_inst_vars(Vars, InstVarSub, !ModeInfo) :-
+    mode_info_get_head_inst_vars(!.ModeInfo, HeadInstVars),
+    ( map.foldl(modecheck_head_inst_var(HeadInstVars), InstVarSub, unit, _) ->
+        true
+    ;
+        mode_info_get_instmap(!.ModeInfo, InstMap),
+        instmap_lookup_vars(InstMap, Vars, VarInsts),
+        WaitingVars = set_of_var.list_to_set(Vars),
+        ModeError = mode_error_no_matching_mode(Vars, VarInsts),
+        mode_info_error(WaitingVars, ModeError, !ModeInfo)
+    ).
+
+:- pred modecheck_head_inst_var(inst_var_sub::in, inst_var::in, mer_inst::in,
+    unit::in, unit::out) is semidet.
+
+modecheck_head_inst_var(HeadInstVars, InstVar, Subst, !Acc) :-
+    ( map.search(HeadInstVars, InstVar, Inst) ->
+        % Subst should not change the constraint.
+        Subst = constrained_inst_vars(InstVars, Inst),
+        set.member(InstVar, InstVars)
+    ;
+        true
+    ).
+
+%-----------------------------------------------------------------------------%
+
 modecheck_set_var_inst_list(Vars0, InitialInsts, FinalInsts, ArgOffset,
         Vars, Goals, !ModeInfo) :-
     (
@@ -1016,5 +1063,101 @@ get_live_vars([Var | Vars], [IsLive | IsLives], LiveVars) :-
     get_live_vars(Vars, IsLives, LiveVars0).
 
 %-----------------------------------------------------------------------------%
+
+get_constrained_inst_vars(ModuleInfo, Modes, Map) :-
+    list.foldl2(get_constrained_insts_in_mode(ModuleInfo), Modes,
+        map.init, Map, set_tree234.init, _Expansions).
+
+:- pred get_constrained_insts_in_mode(module_info::in, mer_mode::in,
+    head_inst_vars::in, head_inst_vars::out, expansions::in, expansions::out)
+    is det.
+
+get_constrained_insts_in_mode(ModuleInfo, Mode, !Map, !Expansions) :-
+    mode_get_insts(ModuleInfo, Mode, InitialInst, FinalInst),
+    get_constrained_insts_in_inst(ModuleInfo, InitialInst, !Map, !Expansions),
+    get_constrained_insts_in_inst(ModuleInfo, FinalInst, !Map, !Expansions).
+
+:- pred get_constrained_insts_in_inst(module_info::in, mer_inst::in,
+    head_inst_vars::in, head_inst_vars::out, expansions::in, expansions::out)
+    is det.
+
+get_constrained_insts_in_inst(ModuleInfo, Inst, !Map, !Expansions) :-
+    (
+        ( Inst = free
+        ; Inst = free(_)
+        ; Inst = not_reached
+        )
+    ;
+        Inst = bound(_, _, BoundInsts),
+        list.foldl2(get_constrained_insts_in_bound_inst(ModuleInfo),
+            BoundInsts, !Map, !Expansions)
+    ;
+        ( Inst = any(_, HOInstInfo)
+        ; Inst = ground(_, HOInstInfo)
+        ),
+        (
+            HOInstInfo = none
+        ;
+            HOInstInfo = higher_order(PredInstInfo),
+            get_constrained_insts_in_ho_inst(ModuleInfo, PredInstInfo,
+                !Map, !Expansions)
+        )
+    ;
+        Inst = constrained_inst_vars(InstVars, _),
+        inst_expand_and_remove_constrained_inst_vars(ModuleInfo,
+            Inst, SubInst),
+        set.fold(add_constrained_inst(SubInst), InstVars, !Map)
+    ;
+        Inst = defined_inst(InstName),
+        ( insert_new(InstName, !Expansions) ->
+            inst_lookup(ModuleInfo, InstName, ExpandedInst),
+            get_constrained_insts_in_inst(ModuleInfo, ExpandedInst,
+                !Map, !Expansions)
+        ;
+            true
+        )
+    ;
+        Inst = inst_var(_),
+        unexpected($module, $pred, "inst_var")
+    ;
+        Inst = abstract_inst(_, _),
+        sorry($module, $pred, "abstract_inst")
+    ).
+
+:- pred get_constrained_insts_in_bound_inst(module_info::in, bound_inst::in,
+    head_inst_vars::in, head_inst_vars::out, expansions::in, expansions::out)
+    is det.
+
+get_constrained_insts_in_bound_inst(ModuleInfo, BoundInst, !Map, !Expansions)
+        :-
+    BoundInst = bound_functor(_ConsId, Insts),
+    list.foldl2(get_constrained_insts_in_inst(ModuleInfo), Insts,
+        !Map, !Expansions).
+
+:- pred get_constrained_insts_in_ho_inst(module_info::in, pred_inst_info::in,
+    head_inst_vars::in, head_inst_vars::out, expansions::in, expansions::out)
+    is det.
+
+get_constrained_insts_in_ho_inst(ModuleInfo, PredInstInfo, !Map, !Expansions)
+        :-
+    PredInstInfo = pred_inst_info(_, Modes, _, _),
+    list.foldl2(get_constrained_insts_in_mode(ModuleInfo), Modes,
+        !Map, !Expansions).
+
+:- pred add_constrained_inst(mer_inst::in, inst_var::in,
+    head_inst_vars::in, head_inst_vars::out) is det.
+
+add_constrained_inst(SubInst, InstVar, !Map) :-
+    ( map.search(!.Map, InstVar, SubInst0) ->
+        ( SubInst0 = SubInst ->
+            true
+        ;
+            unexpected($module, $pred, "SubInst differs")
+        )
+    ;
+        map.det_insert(InstVar, SubInst, !Map)
+    ).
+
+%-----------------------------------------------------------------------------%
 :- end_module check_hlds.modecheck_util.
 %-----------------------------------------------------------------------------%
diff --git a/compiler/modes.m b/compiler/modes.m
index 43b12bc..1384fd7 100644
--- a/compiler/modes.m
+++ b/compiler/modes.m
@@ -687,9 +687,12 @@ do_modecheck_proc(ProcId, PredId, WhatToCheck, MayChangeCalledProc,
         get_live_vars(HeadVars, ArgLives0, LiveVarsList),
         set_of_var.list_to_set(LiveVarsList, LiveVars),
 
+        get_constrained_inst_vars(!.ModuleInfo, ArgModes0, HeadInstVars),
+
         % Initialize the mode info.
         mode_info_init(!.ModuleInfo, PredId, ProcId, Context, LiveVars,
-            InstMap0, WhatToCheck, MayChangeCalledProc, !:ModeInfo),
+            HeadInstVars, InstMap0, WhatToCheck, MayChangeCalledProc,
+            !:ModeInfo),
         mode_info_set_changed_flag(!.Changed, !ModeInfo),
 
         pred_info_get_markers(PredInfo, Markers),
diff --git a/compiler/pd_info.m b/compiler/pd_info.m
index cc851f9..70842d8 100644
--- a/compiler/pd_info.m
+++ b/compiler/pd_info.m
@@ -117,6 +117,7 @@
 :- implementation.
 
 :- import_module check_hlds.inst_match.
+:- import_module check_hlds.modecheck_util.
 :- import_module hlds.hlds_goal.
 :- import_module hlds.hlds_pred.
 :- import_module mdbcomp.prim_data.
@@ -146,12 +147,14 @@ pd_info_init(ModuleInfo, ProcArgInfos, PDInfo) :-
 
 pd_info_init_unfold_info(PredProcId, PredInfo, ProcInfo, !PDInfo) :-
     pd_info_get_module_info(!.PDInfo, ModuleInfo),
+    proc_info_get_argmodes(ProcInfo, ArgModes),
+    get_constrained_inst_vars(ModuleInfo, ArgModes, HeadInstVars),
     proc_info_get_initial_instmap(ProcInfo, ModuleInfo, InstMap),
     CostDelta = 0,
     pd_term.local_term_info_init(LocalTermInfo),
     Parents = set.make_singleton_set(PredProcId),
-    UnfoldInfo = unfold_info(ProcInfo, InstMap, CostDelta, LocalTermInfo,
-        PredInfo, Parents, PredProcId, no, 0, no),
+    UnfoldInfo = unfold_info(ProcInfo, HeadInstVars, InstMap, CostDelta,
+        LocalTermInfo, PredInfo, Parents, PredProcId, no, 0, no),
     pd_info_set_unfold_info(UnfoldInfo, !PDInfo).
 
 pd_info_get_module_info(PDInfo, PDInfo ^ pdi_module_info).
@@ -224,6 +227,7 @@ pd_info_bind_var_to_functors(Var, MainConsId, OtherConsIds, !PDInfo) :-
 :- type unfold_info
     --->    unfold_info(
                 ufi_proc_info       :: proc_info,
+                ufi_head_inst_vars  :: map(inst_var, mer_inst),
                 ufi_instmap         :: instmap,
 
                 % Improvement in cost measured while processing this procedure.
@@ -275,6 +279,8 @@ pd_info_bind_var_to_functors(Var, MainConsId, OtherConsIds, !PDInfo) :-
 :- type branch_info_map(T)  ==  map(T, set(int)).
 
 :- pred pd_info_get_proc_info(pd_info::in, proc_info::out) is det.
+:- pred pd_info_get_head_inst_vars(pd_info::in, map(inst_var, mer_inst)::out)
+    is det.
 :- pred pd_info_get_instmap(pd_info::in, instmap::out) is det.
 :- pred pd_info_get_cost_delta(pd_info::in, int::out) is det.
 :- pred pd_info_get_local_term_info(pd_info::in, local_term_info::out) is det.
@@ -316,6 +322,8 @@ pd_info_bind_var_to_functors(Var, MainConsId, OtherConsIds, !PDInfo) :-
 
 pd_info_get_proc_info(PDInfo, UnfoldInfo ^ ufi_proc_info) :-
     pd_info_get_unfold_info(PDInfo, UnfoldInfo).
+pd_info_get_head_inst_vars(PDInfo, UnfoldInfo ^ ufi_head_inst_vars) :-
+    pd_info_get_unfold_info(PDInfo, UnfoldInfo).
 pd_info_get_instmap(PDInfo, UnfoldInfo ^ ufi_instmap) :-
     pd_info_get_unfold_info(PDInfo, UnfoldInfo).
 pd_info_get_cost_delta(PDInfo, UnfoldInfo ^ ufi_cost_delta) :-
diff --git a/compiler/pd_util.m b/compiler/pd_util.m
index e065afe..83d8169 100644
--- a/compiler/pd_util.m
+++ b/compiler/pd_util.m
@@ -272,14 +272,16 @@ unique_modecheck_goal_live_vars(LiveVars, Goal0, Goal, Errors, !PDInfo) :-
     term.context_init(Context),
     pd_info_get_pred_info(!.PDInfo, PredInfo0),
     pd_info_get_proc_info(!.PDInfo, ProcInfo0),
+    pd_info_get_head_inst_vars(!.PDInfo, HeadInstVars),
     module_info_set_pred_proc_info(PredId, ProcId, PredInfo0, ProcInfo0,
         ModuleInfo0, ModuleInfo1),
 
     % If we perform generalisation, we shouldn't change any called procedures,
     % since that could cause a less efficient version to be chosen.
     MayChangeCalledProc = may_not_change_called_proc,
-    mode_info_init(ModuleInfo1, PredId, ProcId, Context, LiveVars, InstMap0,
-        check_unique_modes, MayChangeCalledProc, ModeInfo0),
+    mode_info_init(ModuleInfo1, PredId, ProcId, Context, LiveVars,
+        HeadInstVars, InstMap0, check_unique_modes, MayChangeCalledProc,
+        ModeInfo0),
 
     unique_modes_check_goal(Goal0, Goal, ModeInfo0, ModeInfo),
     mode_info_get_module_info(ModeInfo, ModuleInfo),
diff --git a/compiler/prog_data.m b/compiler/prog_data.m
index c390ef3..5646b91 100644
--- a/compiler/prog_data.m
+++ b/compiler/prog_data.m
@@ -2430,6 +2430,7 @@ get_type_kind(kinded_type(_, Kind)) = Kind.
 :- type inst_term   ==  term(inst_var_type).
 :- type inst_varset ==  varset(inst_var_type).
 
+:- type head_inst_vars  ==  map(inst_var, mer_inst).
 :- type inst_var_sub    ==  map(inst_var, mer_inst).
 
 % inst_defn/5 is defined in prog_item.m.
diff --git a/compiler/prog_mode.m b/compiler/prog_mode.m
index 78d824e..a535eae 100644
--- a/compiler/prog_mode.m
+++ b/compiler/prog_mode.m
@@ -81,11 +81,17 @@
     mer_inst::in, mer_inst::out) is det.
 
     % inst_list_apply_substitution(Subst, Insts0, Insts) is true
-    % iff Inst is the inst that results from applying Subst to Insts0.
+    % iff Insts is the result of applying Subst to every inst in Insts0.
     %
 :- pred inst_list_apply_substitution(inst_var_sub::in,
     list(mer_inst)::in, list(mer_inst)::out) is det.
 
+    % inst_apply_substitution(Subst, Inst0, Inst) is true iff Inst is the inst
+    % that results from applying Subst to Inst0.
+    %
+:- pred inst_apply_substitution(inst_var_sub::in, mer_inst::in, mer_inst::out)
+    is det.
+
     % mode_list_apply_substitution(Subst, Modes0, Modes) is true
     % iff Mode is the mode that results from applying Subst to Modes0.
     %
@@ -93,7 +99,7 @@
     list(mer_mode)::in, list(mer_mode)::out) is det.
 
 :- pred rename_apart_inst_vars(inst_varset::in, inst_varset::in,
-    list(mer_mode)::in, list(mer_mode)::out) is det.
+    inst_varset::out, list(mer_mode)::in, list(mer_mode)::out) is det.
 
     % inst_contains_unconstrained_var(Inst) iff Inst includes an
     % unconstrained inst variable.
@@ -299,12 +305,6 @@ inst_list_apply_substitution_2(Subst, [A0 | As0], [A | As]) :-
     inst_apply_substitution(Subst, A0, A),
     inst_list_apply_substitution_2(Subst, As0, As).
 
-    % inst_substitute_arg(Inst0, Subst, Inst) is true iff Inst is the inst that
-    % results from substituting all occurrences of Param in Inst0 with Arg.
-    %
-:- pred inst_apply_substitution(inst_var_sub::in, mer_inst::in, mer_inst::out)
-    is det.
-
 inst_apply_substitution(Subst, Inst0, Inst) :-
     (
         ( Inst0 = not_reached
@@ -432,8 +432,8 @@ mode_list_apply_substitution_2(Subst, [A0 | As0], [A | As]) :-
 
 %-----------------------------------------------------------------------------%
 
-rename_apart_inst_vars(VarSet, NewVarSet, Modes0, Modes) :-
-    varset.merge_subst(VarSet, NewVarSet, _, Sub),
+rename_apart_inst_vars(VarSet, NewVarSet, MergedVarSet, Modes0, Modes) :-
+    varset.merge_subst(VarSet, NewVarSet, MergedVarSet, Sub),
     list.map(rename_apart_inst_vars_in_mode(Sub), Modes0, Modes).
 
 :- pred rename_apart_inst_vars_in_mode(substitution(inst_var_type)::in,
diff --git a/tests/invalid/Mmakefile b/tests/invalid/Mmakefile
index 18fb94e..9715f64 100644
--- a/tests/invalid/Mmakefile
+++ b/tests/invalid/Mmakefile
@@ -75,6 +75,7 @@ SINGLEMODULE= \
 	conflicting_fs \
 	conflicting_tabling_pragmas \
 	constrained_poly_insts \
+	constrained_poly_insts2 \
 	constraint_proof_bug_lib \
 	constructor_warning \
 	cyclic_typeclass \
diff --git a/tests/invalid/constrained_poly_insts2.err_exp b/tests/invalid/constrained_poly_insts2.err_exp
new file mode 100644
index 0000000..858b4be
--- /dev/null
+++ b/tests/invalid/constrained_poly_insts2.err_exp
@@ -0,0 +1,64 @@
+constrained_poly_insts2.m:041: In clause for `test6':
+constrained_poly_insts2.m:041:   in argument 2 of call to predicate
+constrained_poly_insts2.m:041:   `constrained_poly_insts2.q'/2:
+constrained_poly_insts2.m:041:   mode error: variable `V_2' has
+constrained_poly_insts2.m:041:   instantiatedness
+constrained_poly_insts2.m:041:   `unique((constrained_poly_insts2.orange))',
+constrained_poly_insts2.m:041:   expected instantiatedness was
+constrained_poly_insts2.m:041:   `bound((constrained_poly_insts2.apple))'.
+constrained_poly_insts2.m:043: In clause for `test8':
+constrained_poly_insts2.m:043:   in argument 2 of call to predicate
+constrained_poly_insts2.m:043:   `constrained_poly_insts2.q'/2:
+constrained_poly_insts2.m:043:   mode error: variable `V_2' has
+constrained_poly_insts2.m:043:   instantiatedness
+constrained_poly_insts2.m:043:   `unique((constrained_poly_insts2.orange))',
+constrained_poly_insts2.m:043:   expected instantiatedness was `(I =<
+constrained_poly_insts2.m:043:   bound((constrained_poly_insts2.apple)))'.
+constrained_poly_insts2.m:048: In clause for `p((pred(in((I =< ground))) is
+constrained_poly_insts2.m:048:   det), in)':
+constrained_poly_insts2.m:048:   in argument 2 (i.e. argument 1 of the called
+constrained_poly_insts2.m:048:   predicate) of higher-order predicate call:
+constrained_poly_insts2.m:048:   mode error: arguments `X' have the following
+constrained_poly_insts2.m:048:   insts:
+constrained_poly_insts2.m:048:     ground
+constrained_poly_insts2.m:048:   which does not match any of the modes for
+constrained_poly_insts2.m:048:   higher-order predicate call.
+constrained_poly_insts2.m:053: In clause for `p2((pred(in((I =< ground))) is
+constrained_poly_insts2.m:053:   det), in((J =<
+constrained_poly_insts2.m:053:   (constrained_poly_insts2.apple))))':
+constrained_poly_insts2.m:053:   in argument 2 (i.e. argument 1 of the called
+constrained_poly_insts2.m:053:   predicate) of higher-order predicate call:
+constrained_poly_insts2.m:053:   mode error: arguments `X' have the following
+constrained_poly_insts2.m:053:   insts:
+constrained_poly_insts2.m:053:     (J =< bound(apple))
+constrained_poly_insts2.m:053:   which does not match any of the modes for
+constrained_poly_insts2.m:053:   higher-order predicate call.
+constrained_poly_insts2.m:058: In clause for `p3((pred(in((I =<
+constrained_poly_insts2.m:058:   (constrained_poly_insts2.apple)))) is det),
+constrained_poly_insts2.m:058:   in((constrained_poly_insts2.orange)))':
+constrained_poly_insts2.m:058:   in argument 2 (i.e. argument 1 of the called
+constrained_poly_insts2.m:058:   predicate) of higher-order predicate call:
+constrained_poly_insts2.m:058:   mode error: variable `X' has instantiatedness
+constrained_poly_insts2.m:058:   `bound(orange)',
+constrained_poly_insts2.m:058:   expected instantiatedness was `(I =<
+constrained_poly_insts2.m:058:   bound(apple))'.
+constrained_poly_insts2.m:063: In clause for `p4((pred(in((I =<
+constrained_poly_insts2.m:063:   (constrained_poly_insts2.apple)))) is det),
+constrained_poly_insts2.m:063:   in((J =< (constrained_poly_insts2.apple))))':
+constrained_poly_insts2.m:063:   in argument 2 (i.e. argument 1 of the called
+constrained_poly_insts2.m:063:   predicate) of higher-order predicate call:
+constrained_poly_insts2.m:063:   mode error: arguments `X' have the following
+constrained_poly_insts2.m:063:   insts:
+constrained_poly_insts2.m:063:     (J =< bound(apple))
+constrained_poly_insts2.m:063:   which does not match any of the modes for
+constrained_poly_insts2.m:063:   higher-order predicate call.
+constrained_poly_insts2.m:068: In clause for `p5((pred(in((I =<
+constrained_poly_insts2.m:068:   (constrained_poly_insts2.apple)))) is det),
+constrained_poly_insts2.m:068:   in((constrained_poly_insts2.apple)))':
+constrained_poly_insts2.m:068:   in argument 2 (i.e. argument 1 of the called
+constrained_poly_insts2.m:068:   predicate) of higher-order predicate call:
+constrained_poly_insts2.m:068:   mode error: arguments `X' have the following
+constrained_poly_insts2.m:068:   insts:
+constrained_poly_insts2.m:068:     bound(apple)
+constrained_poly_insts2.m:068:   which does not match any of the modes for
+constrained_poly_insts2.m:068:   higher-order predicate call.
diff --git a/tests/invalid/constrained_poly_insts2.m b/tests/invalid/constrained_poly_insts2.m
new file mode 100644
index 0000000..c2ee026
--- /dev/null
+++ b/tests/invalid/constrained_poly_insts2.m
@@ -0,0 +1,86 @@
+%-----------------------------------------------------------------------------%
+
+:- module constrained_poly_insts2.
+:- interface.
+
+:- pred test1 is det.
+:- pred test2 is det.
+:- pred test3 is det.
+:- pred test4 is det.
+
+:- pred test5 is det.
+:- pred test6 is det.
+:- pred test7 is det.
+:- pred test8 is det.
+
+%-----------------------------------------------------------------------------%
+%-----------------------------------------------------------------------------%
+
+:- implementation.
+
+:- type fruit --->  apple ; orange.
+:- inst apple --->  apple.
+:- inst orange ---> orange.
+
+:- pred apple1(fruit).
+:- mode apple1(in(apple)) is det.
+
+apple1(_).
+
+:- pred apple2(fruit).
+:- mode apple2(in(I =< apple)) is det.
+
+apple2(_).
+
+test1 :- p(apple1, apple).
+test2 :- p(apple1, orange).
+test3 :- p(apple2, apple).
+test4 :- p(apple2, orange).
+
+test5 :- q(apple1, apple).
+test6 :- q(apple1, orange). % error
+test7 :- q(apple2, apple).
+test8 :- q(apple2, orange). % error
+
+:- pred p(pred(T), T).
+:- mode p(pred(in(I =< ground)) is det, in) is det.
+
+p(P, X) :- P(X). % error
+
+:- pred p2(pred(T), T).
+:- mode p2(pred(in(I =< ground)) is det, in(J =< apple)) is det.
+
+p2(P, X) :- P(X). % error
+
+:- pred p3(pred(T), T).
+:- mode p3(pred(in(I =< apple)) is det, in(orange)) is det.
+
+p3(P, X) :- P(X). % error
+
+:- pred p4(pred(T), T).
+:- mode p4(pred(in(I =< apple)) is det, in(J =< apple)) is det.
+
+p4(P, X) :- P(X). % error
+
+:- pred p5(pred(T), T).
+:- mode p5(pred(in(I =< apple)) is det, in(apple)) is det.
+
+p5(P, X) :- P(X). % error
+
+:- pred q(pred(T), T).
+:- mode q(pred(in(I =< ground)) is det, in(I =< ground)) is det.
+
+q(P, X) :- P(X).
+
+:- pred q2(pred(T), T).
+:- mode q2(pred(in(I =< apple)) is det, in(I =< apple)) is det.
+
+q2(P, X) :- P(X).
+
+:- pred q3(pred(T), T).
+:- mode q3(pred(in(apple)) is det, in(I =< apple)) is det.
+
+q3(P, X) :- P(X).
+
+%-----------------------------------------------------------------------------%
+% vim: ft=mercury ts=4 sts=4 sw=4 et
diff --git a/tests/valid/Mmakefile b/tests/valid/Mmakefile
index b75e2d1..cdab739 100644
--- a/tests/valid/Mmakefile
+++ b/tests/valid/Mmakefile
@@ -90,6 +90,7 @@ OTHER_PROGS= \
 	constr_inst_syntax \
 	constrained_poly_bound_arg \
 	constrained_poly_insts \
+	constrained_poly_multi \
 	constraint_prop_bug \
 	constructor_arg_names \
 	dcg_test \
diff --git a/tests/valid/constrained_poly_multi.m b/tests/valid/constrained_poly_multi.m
new file mode 100644
index 0000000..c08a308
--- /dev/null
+++ b/tests/valid/constrained_poly_multi.m
@@ -0,0 +1,40 @@
+%-----------------------------------------------------------------------------%
+
+:- module constrained_poly_multi.
+:- interface.
+
+:- pred test is det.
+
+:- pred test2 is failure.
+
+%-----------------------------------------------------------------------------%
+%-----------------------------------------------------------------------------%
+
+:- implementation.
+
+:- type either ---> left ; right.
+:- inst left ---> left.
+:- inst right ---> right.
+
+:- type fruit ---> apple ; orange ; pear.
+:- inst apple ---> apple.
+:- inst orange ---> orange.
+
+:- pred either(either, fruit, fruit, fruit).
+:- mode either(in(left), in(I =< apple), in(J =< orange), out(I =< apple))
+    is det.
+:- mode either(in(right), in(I =< apple), in(J =< orange), out(J =< orange))
+    is det.
+
+either(left, X, _, X).
+either(right, _, X, X).
+
+test :-
+    either(left, apple, orange, apple),
+    either(right, apple, orange, orange).
+
+test2 :-
+    either(left, apple, orange, orange).
+
+%-----------------------------------------------------------------------------%
+% vim: ft=mercury ts=4 sts=4 sw=4 et
-- 
1.8.4