[m-rev.] for review: Optimise modechecking of coerce for large types.

Peter Wang novalazy at gmail.com
Tue Apr 20 15:16:29 AEST 2021


compiler/modecheck_coerce.m:
    Keep a map of the bound() insts previously created for types, so that
    we can reuse a bound() inst if a type appears multiple times within a
    bigger type.

    Similarly, record the result of testing whether a bound() inst is within
    a given type, so that we can avoid repeating the test if the same
    type and inst appear multiple times in the same overall type.

    Convert the list of constructors for a type into a map so that a
    constructor can be found in logarithmic time instead of linear time.
    This is worthwhile for a type with a moderate-to-high number of
    constructors, as we have to search for the constructor for each
    functor in a bound() inst; the number of such functors is usually
    similar to the number of constructors in the type, so a linear search
    makes the whole check roughly quadratic in the number of constructors.

diff --git a/compiler/modecheck_coerce.m b/compiler/modecheck_coerce.m
index 4e5253d37..3cef32f57 100644
--- a/compiler/modecheck_coerce.m
+++ b/compiler/modecheck_coerce.m
@@ -47,6 +47,7 @@
 :- import_module parse_tree.prog_type.
 :- import_module parse_tree.set_of_var.
 
+:- import_module map.
 :- import_module maybe.
 :- import_module require.
 :- import_module set.
@@ -179,6 +180,15 @@ modecheck_coerce_vars(ModuleInfo0, X, Y, InstX, InstY, Res, !ModeInfo) :-
     uniqueness::in, mer_type::in, mer_inst::out) is det.
 
 make_bound_inst_for_type(ModuleInfo, Seen0, Uniq, Type, Inst) :-
+    make_bound_inst_for_type_with_cache(ModuleInfo, Seen0, Uniq, Type, Inst,
+        map.init, _Cache).
+
+:- pred make_bound_inst_for_type_with_cache(module_info::in,
+    set(type_ctor)::in, uniqueness::in, mer_type::in, mer_inst::out,
+    map(mer_type, mer_inst)::in, map(mer_type, mer_inst)::out) is det.
+
+make_bound_inst_for_type_with_cache(ModuleInfo, Seen0, Uniq, Type, Inst,
+        !Cache) :-
     (
         Type = type_variable(_, _),
         Inst = ground_or_unique_inst(Uniq)
@@ -191,10 +201,14 @@ make_bound_inst_for_type(ModuleInfo, Seen0, Uniq, Type, Inst) :-
         % type_ctor
         type_to_ctor_det(Type, TypeCtor),
         ( if set.insert_new(TypeCtor, Seen0, Seen) then
-            % type_constructors substitutes type args into constructors.
-            ( if type_constructors(ModuleInfo, Type, Constructors) then
+            ( if map.search(!.Cache, Type, CachedInst) then
+                Inst = CachedInst
+            else if
+                % type_constructors substitutes type args into constructors.
+                type_constructors(ModuleInfo, Type, Constructors)
+            then
                 constructors_to_bound_insts_rec(ModuleInfo, Seen, TypeCtor,
-                    Uniq, Constructors, BoundInsts0),
+                    Uniq, Constructors, BoundInsts0, !Cache),
                 list.sort_and_remove_dups(BoundInsts0, BoundInsts),
                 % XXX A better approximation of InstResults is probably
                 % possible.
@@ -206,7 +220,11 @@ make_bound_inst_for_type(ModuleInfo, Seen0, Uniq, Type, Inst) :-
                     inst_result_contains_types_unknown,
                     inst_result_no_type_ctor_propagated
                 ),
-                Inst = bound(Uniq, InstResults, BoundInsts)
+                Inst = bound(Uniq, InstResults, BoundInsts),
+                % Remember the result in case the type appears multiple times
+                % in the same overall type, which is quite likely for a large
+                % type.
+                map.det_insert(Type, Inst, !Cache)
             else
                 % Type with no definition, e.g. void
                 Inst = ground_or_unique_inst(Uniq)
@@ -219,8 +237,9 @@ make_bound_inst_for_type(ModuleInfo, Seen0, Uniq, Type, Inst) :-
         Type = tuple_type(ArgTypes, _Kind),
         list.length(ArgTypes, Arity),
         ConsId = tuple_cons(Arity),
-        list.map(make_bound_inst_for_type(ModuleInfo, Seen0, Uniq),
-            ArgTypes, ArgInsts),
+        list.map_foldl(
+            make_bound_inst_for_type_with_cache(ModuleInfo, Seen0, Uniq),
+            ArgTypes, ArgInsts, !Cache),
         BoundInst = bound_functor(ConsId, ArgInsts),
         % XXX A better approximation of InstResults is probably possible.
         InstResults = inst_test_results(
@@ -241,45 +260,50 @@ make_bound_inst_for_type(ModuleInfo, Seen0, Uniq, Type, Inst) :-
         sorry($pred, "apply_n_type")
     ;
         Type = kinded_type(Type1, _),
-        make_bound_inst_for_type(ModuleInfo, Seen0, Uniq, Type1, Inst)
+        make_bound_inst_for_type_with_cache(ModuleInfo, Seen0, Uniq,
+            Type1, Inst, !Cache)
     ).
 
     % Similar to mode_util.constructors_to_bound_insts but also produces
     % bound() insts for the constructor arguments, recursively.
     %
 :- pred constructors_to_bound_insts_rec(module_info::in, set(type_ctor)::in,
-    type_ctor::in, uniqueness::in, list(constructor)::in,
-    list(bound_inst)::out) is det.
+    type_ctor::in, uniqueness::in,
+    list(constructor)::in, list(bound_inst)::out,
+    map(mer_type, mer_inst)::in, map(mer_type, mer_inst)::out) is det.
 
 constructors_to_bound_insts_rec(ModuleInfo, Seen, TypeCtor, Uniq,
-        Constructors, BoundInsts) :-
+        Constructors, BoundInsts, !Cache) :-
     constructors_to_bound_insts_loop_over_ctors(ModuleInfo, Seen, TypeCtor,
-        Uniq, Constructors, BoundInsts).
+        Uniq, Constructors, BoundInsts, !Cache).
 
 :- pred constructors_to_bound_insts_loop_over_ctors(module_info::in,
-    set(type_ctor)::in, type_ctor::in, uniqueness::in, list(constructor)::in,
-    list(bound_inst)::out) is det.
+    set(type_ctor)::in, type_ctor::in, uniqueness::in,
+    list(constructor)::in, list(bound_inst)::out,
+    map(mer_type, mer_inst)::in, map(mer_type, mer_inst)::out) is det.
 
 constructors_to_bound_insts_loop_over_ctors(_ModuleInfo, _Seen, _TypeCtor,
-        _Uniq, [], []).
+        _Uniq, [], [], !Cache).
 constructors_to_bound_insts_loop_over_ctors(ModuleInfo, Seen, TypeCtor,
-        Uniq, [Ctor | Ctors], [BoundInst | BoundInsts]) :-
+        Uniq, [Ctor | Ctors], [BoundInst | BoundInsts], !Cache) :-
     Ctor = ctor(_Ordinal, _MaybeExistConstraints, Name, Args, _Arity, _Ctxt),
-    ctor_arg_list_to_inst_list(ModuleInfo, Seen, Uniq, Args, ArgInsts),
+    ctor_arg_list_to_inst_list(ModuleInfo, Seen, Uniq, Args, ArgInsts, !Cache),
     list.length(ArgInsts, Arity),
     BoundInst = bound_functor(cons(Name, Arity, TypeCtor), ArgInsts),
     constructors_to_bound_insts_loop_over_ctors(ModuleInfo, Seen, TypeCtor,
-        Uniq, Ctors, BoundInsts).
+        Uniq, Ctors, BoundInsts, !Cache).
 
 :- pred ctor_arg_list_to_inst_list(module_info::in, set(type_ctor)::in,
-    uniqueness::in, list(constructor_arg)::in, list(mer_inst)::out) is det.
+    uniqueness::in, list(constructor_arg)::in, list(mer_inst)::out,
+    map(mer_type, mer_inst)::in, map(mer_type, mer_inst)::out) is det.
 
-ctor_arg_list_to_inst_list(_ModuleInfo, _Seen, _Uniq, [], []).
+ctor_arg_list_to_inst_list(_ModuleInfo, _Seen, _Uniq, [], [], !Cache).
 ctor_arg_list_to_inst_list(ModuleInfo, Seen, Uniq,
-        [Arg | Args], [ArgInst | ArgInsts]) :-
+        [Arg | Args], [ArgInst | ArgInsts], !Cache) :-
     Arg = ctor_arg(_MaybeFieldName, ArgType, _Context),
-    make_bound_inst_for_type(ModuleInfo, Seen, Uniq, ArgType, ArgInst),
-    ctor_arg_list_to_inst_list(ModuleInfo, Seen, Uniq, Args, ArgInsts).
+    make_bound_inst_for_type_with_cache(ModuleInfo, Seen, Uniq,
+        ArgType, ArgInst, !Cache),
+    ctor_arg_list_to_inst_list(ModuleInfo, Seen, Uniq, Args, ArgInsts, !Cache).
 
 :- func ground_or_unique_inst(uniqueness) = mer_inst.
 
@@ -292,6 +316,16 @@ ground_or_unique_inst(Uniq) =
 
 %---------------------------------------------------------------------------%
 
+:- type bound_inst_cache == map(type_and_bound_inst, mer_inst).
+
+:- type type_and_bound_inst
+    --->    type_and_bound_inst(
+                mer_type,
+                % These fields are from a bound() inst.
+                uniqueness,
+                list(bound_inst)
+            ).
+
     % Check that a bound() inst only includes functors that are constructors of
     % the given type, recursively. Insts are otherwise assumed to be valid for
     % the type, and not checked to be valid for the type in other respects.
@@ -303,6 +337,15 @@ ground_or_unique_inst(Uniq) =
     mer_inst::in, mer_inst::out) is semidet.
 
 check_bound_functors_in_type(ModuleInfo, Type, Inst0, Inst) :-
+    check_bound_functors_in_type_with_cache(ModuleInfo, Type, Inst0, Inst,
+        map.init, _Cache).
+
+:- pred check_bound_functors_in_type_with_cache(module_info::in, mer_type::in,
+    mer_inst::in, mer_inst::out, bound_inst_cache::in, bound_inst_cache::out)
+    is semidet.
+
+check_bound_functors_in_type_with_cache(ModuleInfo, Type, Inst0, Inst, !Cache)
+        :-
     inst_expand(ModuleInfo, Inst0, Inst1),
     require_complete_switch [Type]
     (
@@ -313,10 +356,12 @@ check_bound_functors_in_type(ModuleInfo, Type, Inst0, Inst) :-
         Inst = Inst1
     ;
         Type = defined_type(_, _, _),
-        check_bound_functors_in_defined_type(ModuleInfo, Type, Inst1, Inst)
+        check_bound_functors_in_defined_type(ModuleInfo, Type, Inst1, Inst,
+            !Cache)
     ;
         Type = tuple_type(ArgTypes, _Kind),
-        check_bound_functors_in_tuple(ModuleInfo, ArgTypes, Inst1, Inst)
+        check_bound_functors_in_tuple(ModuleInfo, ArgTypes, Inst1, Inst,
+            !Cache)
     ;
         Type = higher_order_type(_PredOrFunc, ArgTypes, _HOInstInfo, _Purity,
             _EvalMethod),
@@ -327,33 +372,51 @@ check_bound_functors_in_type(ModuleInfo, Type, Inst0, Inst) :-
         fail
     ;
         Type = kinded_type(Type1, _Kind),
-        check_bound_functors_in_type(ModuleInfo, Type1, Inst1, Inst)
+        check_bound_functors_in_type_with_cache(ModuleInfo, Type1, Inst1, Inst,
+            !Cache)
     ).
 
 %---------------------%
 
 :- pred check_bound_functors_in_defined_type(module_info::in, mer_type::in,
-    mer_inst::in, mer_inst::out) is semidet.
+    mer_inst::in, mer_inst::out, bound_inst_cache::in, bound_inst_cache::out)
+    is semidet.
 
-check_bound_functors_in_defined_type(ModuleInfo, Type, Inst0, Inst) :-
+check_bound_functors_in_defined_type(ModuleInfo, Type, Inst0, Inst, !Cache) :-
     require_complete_switch [Inst0]
     (
         Inst0 = bound(Uniq, _InstResults0, BoundInsts0),
-        type_to_ctor(Type, TypeCtor),
-        % type_constructors substitutes type args into constructors.
-        type_constructors(ModuleInfo, Type, Constructors),
-        list.map(
-            check_bound_functor_in_du_type(ModuleInfo, TypeCtor, Constructors),
-            BoundInsts0, BoundInsts),
-        % XXX A better approximation of InstResults is probably possible.
-        Inst = bound(Uniq, inst_test_no_results, BoundInsts)
+        TypeAndBoundInst0 = type_and_bound_inst(Type, Uniq, BoundInsts0),
+        ( if map.search(!.Cache, TypeAndBoundInst0, CachedInst) then
+            Inst = CachedInst
+        else
+            type_to_ctor(Type, TypeCtor),
+            % type_constructors substitutes type args into constructors.
+            type_constructors(ModuleInfo, Type, Constructors),
+            list.foldl(build_ctors_map, Constructors, map.init, CtorsMap),
+            list.map_foldl(
+                check_bound_functor_in_du_type(ModuleInfo, TypeCtor, CtorsMap),
+                BoundInsts0, BoundInsts, !Cache),
+            % XXX A better approximation of InstResults is probably possible.
+            Inst = bound(Uniq, inst_test_no_results, BoundInsts),
+            % Remember that the bound functors in Inst0 are in Type.
+            % This saves repeating the same test if a type (and inst) appears
+            % multiple times in the same overall type, which is quite likely
+            % for a large type. We use map.insert in case a recursive call
+            % inserted the result first.
+            ( if map.insert(TypeAndBoundInst0, Inst, !Cache) then
+                true
+            else
+                true
+            )
+        )
     ;
         Inst0 = ground(_Uniq, _HOInstInfo),
         Inst = Inst0
     ;
         Inst0 = constrained_inst_vars(InstVars, SubInst0),
         check_bound_functors_in_defined_type(ModuleInfo, Type,
-            SubInst0, SubInst),
+            SubInst0, SubInst, !Cache),
         Inst = constrained_inst_vars(InstVars, SubInst)
     ;
         ( Inst0 = free
@@ -369,64 +432,76 @@ check_bound_functors_in_defined_type(ModuleInfo, Type, Inst0, Inst) :-
         unexpected($pred, "unexpanded inst")
     ).
 
+%---------------------%
+
+:- type ctors_map == map(ctor_name_and_arity, constructor).
+
+:- type ctor_name_and_arity
+    --->    ctor_name_and_arity(string, int).
+
+:- pred build_ctors_map(constructor::in, ctors_map::in, ctors_map::out) is det.
+
+build_ctors_map(Ctor, !CtorsMap) :-
+    Ctor = ctor(_, _, ConsName, _, Arity, _),
+    unqualify_name(ConsName) = Name,
+    NameArity = ctor_name_and_arity(Name, Arity),
+    map.det_insert(NameArity, Ctor, !CtorsMap).
+
+%---------------------%
+
 :- pred check_bound_functor_in_du_type(module_info::in, type_ctor::in,
-    list(constructor)::in, bound_inst::in, bound_inst::out) is semidet.
+    ctors_map::in, bound_inst::in, bound_inst::out,
+    bound_inst_cache::in, bound_inst_cache::out) is semidet.
 
-check_bound_functor_in_du_type(ModuleInfo, TypeCtor, Ctors,
-        BoundInst0, BoundInst) :-
+check_bound_functor_in_du_type(ModuleInfo, TypeCtor, CtorsMap,
+        BoundInst0, BoundInst, !Cache) :-
     BoundInst0 = bound_functor(ConsId0, ArgInsts0),
     ConsId0 = cons(SymName0, Arity, _TypeCtor0),
     Name = unqualify_name(SymName0),
-    find_first_matching_constructor_unqual(Name, Arity, Ctors, MatchingCtor),
+    NameArity = ctor_name_and_arity(Name, Arity),
+    map.search(CtorsMap, NameArity, MatchingCtor),
+
     MatchingCtor = ctor(_, _, SymName, CtorArgs, _, _),
     ConsId = cons(SymName, Arity, TypeCtor),
     check_bound_functors_in_ctor_args(ModuleInfo, CtorArgs,
-        ArgInsts0, ArgInsts),
+        ArgInsts0, ArgInsts, !Cache),
     BoundInst = bound_functor(ConsId, ArgInsts).
 
-:- pred find_first_matching_constructor_unqual(string::in, int::in,
-    list(constructor)::in, constructor::out) is semidet.
-
-find_first_matching_constructor_unqual(Name, Arity, [Ctor | Ctors],
-        MatchingCtor) :-
-    ( if
-        Ctor = ctor(_, _, ConsName, _, Arity, _),
-        unqualify_name(ConsName) = Name
-    then
-        MatchingCtor = Ctor
-    else
-        find_first_matching_constructor_unqual(Name, Arity, Ctors, MatchingCtor)
-    ).
-
 :- pred check_bound_functors_in_ctor_args(module_info::in,
-    list(constructor_arg)::in, list(mer_inst)::in, list(mer_inst)::out)
-    is semidet.
+    list(constructor_arg)::in, list(mer_inst)::in, list(mer_inst)::out,
+    bound_inst_cache::in, bound_inst_cache::out) is semidet.
 
-check_bound_functors_in_ctor_args(_, [], [], []).
+check_bound_functors_in_ctor_args(_, [], [], [], !Cache).
 check_bound_functors_in_ctor_args(ModuleInfo, [CtorArg | CtorArgs],
-        [ArgInst0 | ArgInsts0], [ArgInst | ArgInsts]) :-
-    check_bound_functors_in_ctor_arg(ModuleInfo, CtorArg, ArgInst0, ArgInst),
+        [ArgInst0 | ArgInsts0], [ArgInst | ArgInsts], !Cache) :-
+    check_bound_functors_in_ctor_arg(ModuleInfo, CtorArg, ArgInst0, ArgInst,
+        !Cache),
     check_bound_functors_in_ctor_args(ModuleInfo, CtorArgs,
-        ArgInsts0, ArgInsts).
+        ArgInsts0, ArgInsts, !Cache).
 
 :- pred check_bound_functors_in_ctor_arg(module_info::in, constructor_arg::in,
-    mer_inst::in, mer_inst::out) is semidet.
+    mer_inst::in, mer_inst::out, bound_inst_cache::in, bound_inst_cache::out)
+    is semidet.
 
-check_bound_functors_in_ctor_arg(ModuleInfo, CtorArg, ArgInst0, ArgInst) :-
+check_bound_functors_in_ctor_arg(ModuleInfo, CtorArg, ArgInst0, ArgInst,
+        !Cache) :-
     CtorArg = ctor_arg(_MaybeFieldName, ArgType, _Context),
-    check_bound_functors_in_type(ModuleInfo, ArgType, ArgInst0, ArgInst).
+    check_bound_functors_in_type_with_cache(ModuleInfo, ArgType,
+        ArgInst0, ArgInst, !Cache).
 
 %---------------------%
 
 :- pred check_bound_functors_in_tuple(module_info::in, list(mer_type)::in,
-    mer_inst::in, mer_inst::out) is semidet.
+    mer_inst::in, mer_inst::out, bound_inst_cache::in, bound_inst_cache::out)
+    is semidet.
 
-check_bound_functors_in_tuple(ModuleInfo, ArgTypes, Inst0, Inst) :-
+check_bound_functors_in_tuple(ModuleInfo, ArgTypes, Inst0, Inst, !Cache) :-
     require_complete_switch [Inst0]
     (
         Inst0 = bound(Uniq, _InstTestsResults, BoundInsts0),
-        list.map(bound_check_bound_functors_in_tuple(ModuleInfo, ArgTypes),
-            BoundInsts0, BoundInsts),
+        list.map_foldl(
+            bound_check_bound_functors_in_tuple(ModuleInfo, ArgTypes),
+            BoundInsts0, BoundInsts, !Cache),
         Inst = bound(Uniq, inst_test_no_results, BoundInsts)
     ;
         Inst0 = ground(_Uniq, _HOInstInfo),
@@ -449,25 +524,29 @@ check_bound_functors_in_tuple(ModuleInfo, ArgTypes, Inst0, Inst) :-
     ).
 
 :- pred bound_check_bound_functors_in_tuple(module_info::in,
-    list(mer_type)::in, bound_inst::in, bound_inst::out) is semidet.
+    list(mer_type)::in, bound_inst::in, bound_inst::out,
+    bound_inst_cache::in, bound_inst_cache::out) is semidet.
 
 bound_check_bound_functors_in_tuple(ModuleInfo, ArgTypes,
-        BoundInst0, BoundInst) :-
+        BoundInst0, BoundInst, !Cache) :-
     BoundInst0 = bound_functor(ConsId, ArgInsts0),
     list.length(ArgTypes, Arity),
     ConsId = tuple_cons(Arity),
     check_bound_functors_in_tuple_args(ModuleInfo, ArgTypes,
-        ArgInsts0, ArgInsts),
+        ArgInsts0, ArgInsts, !Cache),
     BoundInst = bound_functor(ConsId, ArgInsts).
 
 :- pred check_bound_functors_in_tuple_args(module_info::in, list(mer_type)::in,
-    list(mer_inst)::in, list(mer_inst)::out) is semidet.
+    list(mer_inst)::in, list(mer_inst)::out,
+    bound_inst_cache::in, bound_inst_cache::out) is semidet.
 
-check_bound_functors_in_tuple_args(_ModuleInfo, [], [], []).
+check_bound_functors_in_tuple_args(_ModuleInfo, [], [], [], !Cache).
 check_bound_functors_in_tuple_args(ModuleInfo,
-        [Type | Types], [Inst0 | Insts0], [Inst | Insts]) :-
-    check_bound_functors_in_type(ModuleInfo, Type, Inst0, Inst),
-    check_bound_functors_in_tuple_args(ModuleInfo, Types, Insts0, Insts).
+        [Type | Types], [Inst0 | Insts0], [Inst | Insts], !Cache) :-
+    check_bound_functors_in_type_with_cache(ModuleInfo, Type, Inst0, Inst,
+        !Cache),
+    check_bound_functors_in_tuple_args(ModuleInfo, Types, Insts0, Insts,
+        !Cache).
 
 %---------------------%
 
-- 
2.30.0



More information about the reviews mailing list