[m-rev.] for review: Make subtypes inherit user-defined equality/comparison from base type.

Peter Wang novalazy at gmail.com
Mon Apr 12 16:55:47 AEST 2021


Values of a subtype are tested for equality or compared by first
converting them to the base type and then testing/comparing the
converted values. This means that subtypes implicitly inherit any
user-defined equality/comparison from the base type. Allowing a subtype
to also define its own equality/comparison predicates seems more
confusing than useful, so disallow it.

compiler/prog_data.m:
    Add an option noncanon_subtype to the noncanonical type
    (the argument of maybe_canonical's noncanon functor).

compiler/add_type.m:
    Report an error if a subtype defined in the current module
    has user-defined equality or comparison.

    Set the maybe_canonical field of the hlds_type_defn for a subtype to
    noncanon(noncanon_subtype) if the base type is noncanonical. Doing so
    makes the subtype's type_ctor_info have an MR_TypeCtorRep value of
    MR_TYPECTOR_REP_*_USEREQ.

compiler/unify_proc.m:
    Handle subtypes ahead of (other) noncanonical types when generating
    unify or compare proc clauses.

compiler/dead_proc_elim.m:
compiler/hlds_out_module.m:
compiler/intermod.m:
compiler/parse_tree_out.m:
compiler/post_term_analysis.m:
compiler/special_pred.m:
    Conform to changes.

doc/reference_manual.texi:
    Document change.

tests/hard_coded/Mmakefile:
tests/hard_coded/subtype_user_compare.exp:
tests/hard_coded/subtype_user_compare.m:
tests/hard_coded/subtype_user_compare2.m:
tests/invalid/Mmakefile:
tests/invalid/subtype_user_compare.err_exp:
tests/invalid/subtype_user_compare.m:
    Add test cases.

diff --git a/compiler/add_type.m b/compiler/add_type.m
index f2a7d5f47..9078fbc27 100644
--- a/compiler/add_type.m
+++ b/compiler/add_type.m
@@ -79,7 +79,6 @@
 :- import_module set.
 :- import_module string.
 :- import_module term.
-:- import_module unit.
 
 :- inst hlds_type_body_du for hlds_type_body/0
     --->    hlds_du_type(ground, ground, ground, ground, ground).
@@ -331,8 +330,8 @@ module_add_type_defn_mercury(TypeStatus1, TypeCtor, TypeParams,
     ),
     (
         ParseTreeTypeDefn = parse_tree_du_type(DetailsDu),
-        check_for_dummy_type_with_unify_compare(TypeStatus, TypeCtor,
-            DetailsDu, Context, !FoundInvalidType, !Specs)
+        check_for_invalid_user_defined_unify_compare(TypeStatus,
+            TypeCtor, DetailsDu, Context, !FoundInvalidType, !Specs)
     ;
         ParseTreeTypeDefn = parse_tree_eqv_type(DetailsEqv),
         check_for_polymorphic_eqv_type_with_monomorphic_body(TypeStatus,
@@ -628,47 +627,53 @@ maybe_report_multiple_def_error(TypeStatus, TypeCtor, Context, OldDefn,
 
 %---------------------%
 
-:- pred check_for_dummy_type_with_unify_compare(type_status::in,
+:- pred check_for_invalid_user_defined_unify_compare(type_status::in,
     type_ctor::in, type_details_du::in, prog_context::in,
     found_invalid_type::in, found_invalid_type::out,
     list(error_spec)::in, list(error_spec)::out) is det.
 
-check_for_dummy_type_with_unify_compare(TypeStatus, TypeCtor, DetailsDu,
+check_for_invalid_user_defined_unify_compare(TypeStatus, TypeCtor, DetailsDu,
         Context, !FoundInvalidType, !Specs) :-
     ( if
-        % Discriminated unions whose definition consists of a single
-        % zero-arity constructor are dummy types. Dummy types are not allowed
-        % to have user-defined equality or comparison.
         DetailsDu = type_details_du(MaybeSuperType, Ctors, MaybeCanonical,
             _MaybeDirectArg),
+        MaybeCanonical = noncanon(_),
         (
-            MaybeSuperType = no
+            MaybeSuperType = no,
+            % Discriminated unions whose definition consists of a single
+            % zero-arity constructor are dummy types. Dummy types are not
+            % allowed to have user-defined equality or comparison.
+            Ctors = one_or_more(Ctor, []),
+            Ctor ^ cons_args = []
         ;
             MaybeSuperType = yes(_)
-            % XXX SUBTYPE A subtype with a single zero-arity constructor
-            % is not necessarily a dummy type, so this check will incorrectly
-            % prevent such a subtype from having user-defined equality or
-            % comparison (however unlikely that would be).
         ),
-        Ctors = one_or_more(Ctor, []),
-        Ctor ^ cons_args = [],
-        MaybeCanonical = noncanon(_),
         % Only report errors for types defined in this module.
         type_status_defined_in_this_module(TypeStatus) = yes
     then
-        DummyMainPieces = [words("Error: the type"),
-            unqual_type_ctor(TypeCtor), words("contains no information,"),
-            words("and as such it is not allowed to have"),
-            words("user-defined equality or comparison."), nl],
-        DummyVerbosePieces = [words("Discriminated union types"),
-            words("whose body consists of a single zero-arity constructor"),
-            words("cannot have user-defined equality or comparison."), nl],
-        DummyMsg = simple_msg(Context,
-            [always(DummyMainPieces),
-            verbose_only(verbose_once, DummyVerbosePieces)]),
-        DummySpec = error_spec($pred, severity_error, phase_parse_tree_to_hlds,
-            [DummyMsg]),
-        !:Specs = [DummySpec | !.Specs],
+        (
+            MaybeSuperType = no,
+            MainPieces = [words("Error: the type"),
+                unqual_type_ctor(TypeCtor), words("contains no information,"),
+                words("and as such it is not allowed to have"),
+                words("user-defined equality or comparison."), nl],
+            VerbosePieces = [words("Discriminated union types"),
+                words("whose body consists of a single zero-arity constructor"),
+                words("cannot have user-defined equality or comparison."), nl],
+            Msg = simple_msg(Context,
+                [always(MainPieces),
+                verbose_only(verbose_once, VerbosePieces)])
+        ;
+            MaybeSuperType = yes(_),
+            Pieces = [words("Error: the subtype"),
+                unqual_type_ctor(TypeCtor),
+                words("is not allowed to have"),
+                words("user-defined equality or comparison."), nl],
+            Msg = simplest_msg(Context, Pieces)
+        ),
+        Spec = error_spec($pred, severity_error, phase_parse_tree_to_hlds,
+            [Msg]),
+        !:Specs = [Spec | !.Specs],
         !:FoundInvalidType = found_invalid_type
     else
         true
@@ -926,7 +931,21 @@ add_du_ctors_check_subtype_check_foreign_type(TypeTable, TypeCtor, TypeDefn,
         (
             MaybeSuperType = yes(SuperType),
             check_subtype_defn(TypeTable, TVarSet, TypeCtor, TypeDefn, Body,
-                SuperType, !FoundInvalidType, !Specs)
+                SuperType, MaybeSetSubtypeNoncanon, !FoundInvalidType, !Specs),
+            (
+                MaybeSetSubtypeNoncanon = do_not_set_subtype_noncanon
+            ;
+                MaybeSetSubtypeNoncanon = set_subtype_noncanon,
+                % Set noncanonical flag on subtype definition if the base type
+                % is noncanonical.
+                NoncanonBody = Body ^ du_type_canonical :=
+                    noncanon(noncanon_subtype),
+                set_type_defn_body(NoncanonBody, TypeDefn, NoncanonTypeDefn),
+                module_info_get_type_table(!.ModuleInfo, TypeTable0),
+                replace_type_ctor_defn(TypeCtor, NoncanonTypeDefn,
+                    TypeTable0, TypeTable1),
+                module_info_set_type_table(TypeTable1, !ModuleInfo)
+            )
         ;
             MaybeSuperType = no
         ),
@@ -1206,13 +1225,18 @@ check_foreign_type_for_current_target(TypeCtor, ForeignTypeBody, PrevErrors,
 %---------------------------------------------------------------------------%
 %---------------------------------------------------------------------------%
 
+:- type maybe_set_subtype_noncanonical
+    --->    do_not_set_subtype_noncanon
+    ;       set_subtype_noncanon.
+
 :- pred check_subtype_defn(type_table::in, tvarset::in, type_ctor::in,
     hlds_type_defn::in, hlds_type_body::in(hlds_type_body_du), mer_type::in,
+    maybe_set_subtype_noncanonical::out,
     found_invalid_type::in, found_invalid_type::out,
     list(error_spec)::in, list(error_spec)::out) is det.
 
 check_subtype_defn(TypeTable, TVarSet, TypeCtor, TypeDefn, TypeBody, SuperType,
-        !FoundInvalidType, !Specs) :-
+        MaybeSetSubtypeNoncanon, !FoundInvalidType, !Specs) :-
     hlds_data.get_type_defn_status(TypeDefn, OrigTypeStatus),
     hlds_data.get_type_defn_context(TypeDefn, Context),
     ( if type_to_ctor_and_args(SuperType, SuperTypeCtor, SuperTypeArgs) then
@@ -1223,14 +1247,15 @@ check_subtype_defn(TypeTable, TVarSet, TypeCtor, TypeDefn, TypeBody, SuperType,
             SearchRes = ok(SuperTypeDefn),
             check_subtype_defn_2(TypeTable, TypeCtor, TypeDefn, TypeBody,
                 SuperTypeCtor, SuperTypeDefn, SuperTypeArgs, Seen1,
-                !FoundInvalidType, !Specs)
+                MaybeSetSubtypeNoncanon, !FoundInvalidType, !Specs)
         ;
             SearchRes = error(Error),
             Pieces = supertype_ctor_defn_error_pieces(SuperTypeCtor, Error),
             Spec = simplest_spec($pred, severity_error,
                 phase_parse_tree_to_hlds, Context, Pieces),
             !:Specs = [Spec | !.Specs],
-            !:FoundInvalidType = found_invalid_type
+            !:FoundInvalidType = found_invalid_type,
+            MaybeSetSubtypeNoncanon = do_not_set_subtype_noncanon
         )
     else
         SuperTypeStr = mercury_type_to_string(TVarSet, print_name_only,
@@ -1241,18 +1266,20 @@ check_subtype_defn(TypeTable, TVarSet, TypeCtor, TypeDefn, TypeBody, SuperType,
         Spec = simplest_spec($pred, severity_error, phase_parse_tree_to_hlds,
             Context, Pieces),
         !:Specs = [Spec | !.Specs],
-        !:FoundInvalidType = found_invalid_type
+        !:FoundInvalidType = found_invalid_type,
+        MaybeSetSubtypeNoncanon = do_not_set_subtype_noncanon
     ).
 
 :- pred check_subtype_defn_2(type_table::in,
     type_ctor::in, hlds_type_defn::in, hlds_type_body::in(hlds_type_body_du),
     type_ctor::in, hlds_type_defn::in, list(mer_type)::in, set(type_ctor)::in,
+    maybe_set_subtype_noncanonical::out,
     found_invalid_type::in, found_invalid_type::out,
     list(error_spec)::in, list(error_spec)::out) is det.
 
 check_subtype_defn_2(TypeTable, TypeCtor, TypeDefn, TypeBody,
         SuperTypeCtor, SuperTypeDefn, SuperTypeArgs, Seen0,
-        !FoundInvalidType, !Specs) :-
+        MaybeSetSubtypeNoncanon, !FoundInvalidType, !Specs) :-
     hlds_data.get_type_defn_context(TypeDefn, Context),
     hlds_data.get_type_defn_body(SuperTypeDefn, SuperTypeBody),
     (
@@ -1261,7 +1288,7 @@ check_subtype_defn_2(TypeTable, TypeCtor, TypeDefn, TypeBody,
             IsForeign = no,
             check_subtype_defn_3(TypeTable, TypeCtor, TypeDefn, TypeBody,
                 SuperTypeCtor, SuperTypeDefn, SuperTypeBody, SuperTypeArgs,
-                Seen0, !FoundInvalidType, !Specs)
+                Seen0, MaybeSetSubtypeNoncanon, !FoundInvalidType, !Specs)
         ;
             IsForeign = yes(_),
             Pieces = [words("Error:"), unqual_type_ctor(SuperTypeCtor),
@@ -1270,7 +1297,8 @@ check_subtype_defn_2(TypeTable, TypeCtor, TypeDefn, TypeBody,
             Spec = simplest_spec($pred, severity_error,
                 phase_parse_tree_to_hlds, Context, Pieces),
             !:Specs = [Spec | !.Specs],
-            !:FoundInvalidType = found_invalid_type
+            !:FoundInvalidType = found_invalid_type,
+            MaybeSetSubtypeNoncanon = do_not_set_subtype_noncanon
         )
     ;
         ( SuperTypeBody = hlds_eqv_type(_)
@@ -1285,24 +1313,33 @@ check_subtype_defn_2(TypeTable, TypeCtor, TypeDefn, TypeBody,
         Spec = simplest_spec($pred, severity_error, phase_parse_tree_to_hlds,
             Context, Pieces),
         !:Specs = [Spec | !.Specs],
-        !:FoundInvalidType = found_invalid_type
+        !:FoundInvalidType = found_invalid_type,
+        MaybeSetSubtypeNoncanon = do_not_set_subtype_noncanon
     ).
 
 :- pred check_subtype_defn_3(type_table::in,
     type_ctor::in, hlds_type_defn::in, hlds_type_body::in(hlds_type_body_du),
     type_ctor::in, hlds_type_defn::in, hlds_type_body::in(hlds_type_body_du),
     list(mer_type)::in, set(type_ctor)::in,
+    maybe_set_subtype_noncanonical::out,
     found_invalid_type::in, found_invalid_type::out,
     list(error_spec)::in, list(error_spec)::out) is det.
 
 check_subtype_defn_3(TypeTable, TypeCtor, TypeDefn, TypeBody,
         SuperTypeCtor, SuperTypeDefn, SuperTypeBody, SuperTypeArgs,
-        Seen0, !FoundInvalidType, !Specs) :-
+        Seen0, MaybeSetSubtypeNoncanon, !FoundInvalidType, !Specs) :-
     hlds_data.get_type_defn_status(TypeDefn, TypeStatus),
     check_subtype_has_base_type(TypeTable, TypeStatus, SuperTypeCtor,
         SuperTypeDefn, MaybeBaseTypeError, Seen0, _Seen),
     (
-        MaybeBaseTypeError = ok(unit),
+        MaybeBaseTypeError = ok(BaseMaybeCanonical),
+        (
+            BaseMaybeCanonical = canon,
+            MaybeSetSubtypeNoncanon = do_not_set_subtype_noncanon
+        ;
+            BaseMaybeCanonical = noncanon(_),
+            MaybeSetSubtypeNoncanon = set_subtype_noncanon
+        ),
         check_subtype_ctors(TypeTable, TypeCtor, TypeDefn, TypeBody,
             SuperTypeCtor, SuperTypeDefn, SuperTypeBody, SuperTypeArgs,
             !FoundInvalidType, !Specs)
@@ -1317,25 +1354,27 @@ check_subtype_defn_3(TypeTable, TypeCtor, TypeDefn, TypeBody,
                 phase_parse_tree_to_hlds, Context, Pieces),
             !:Specs = [Spec | !.Specs]
         ),
-        !:FoundInvalidType = found_invalid_type
+        !:FoundInvalidType = found_invalid_type,
+        MaybeSetSubtypeNoncanon = do_not_set_subtype_noncanon
     ).
 
 %---------------------%
 
 :- pred check_subtype_has_base_type(type_table::in, type_status::in,
     type_ctor::in, hlds_type_defn::in,
-    maybe_error(unit, list(format_component))::out,
+    maybe_error(maybe_canonical, list(format_component))::out,
     set(type_ctor)::in, set(type_ctor)::out) is det.
 
 check_subtype_has_base_type(TypeTable, OrigTypeStatus, CurTypeCtor,
         CurTypeDefn, MaybeError, !Seen) :-
     hlds_data.get_type_defn_body(CurTypeDefn, CurTypeBody),
     (
-        CurTypeBody = hlds_du_type(_, MaybeSuperType, _, _, IsForeign),
+        CurTypeBody = hlds_du_type(_, MaybeSuperType, MaybeCanonical, _,
+            IsForeign),
         (
             IsForeign = no,
             MaybeSuperType = no,
-            MaybeError = ok(unit)
+            MaybeError = ok(MaybeCanonical)
         ;
             IsForeign = no,
             MaybeSuperType = yes(SuperType),
diff --git a/compiler/dead_proc_elim.m b/compiler/dead_proc_elim.m
index 9ff18181d..a26aab937 100644
--- a/compiler/dead_proc_elim.m
+++ b/compiler/dead_proc_elim.m
@@ -1434,6 +1434,8 @@ dead_pred_initialize_maybe_canonical(ModuleInfo, MaybeCanon, !NeededPreds) :-
             set_tree234.insert_list(CmpPredIds, !NeededPreds)
         ;
             NonCanonical = noncanon_abstract(_IsSolverType)
+        ;
+            NonCanonical = noncanon_subtype
         )
     ).
 
diff --git a/compiler/hlds_out_module.m b/compiler/hlds_out_module.m
index 2a5c641d2..2a355338f 100644
--- a/compiler/hlds_out_module.m
+++ b/compiler/hlds_out_module.m
@@ -583,6 +583,9 @@ maybe_canonical_to_simple_string(MaybeCanonical) = String :-
                 IsSolver = solver_type,
                 String = "noncanon_abstract_solver"
             )
+        ;
+            NonCanonical = noncanon_subtype,
+            String = "noncanon_subtype"
         )
     ).
 
diff --git a/compiler/intermod.m b/compiler/intermod.m
index 73646388e..f466d94ff 100644
--- a/compiler/intermod.m
+++ b/compiler/intermod.m
@@ -1297,7 +1297,9 @@ resolve_unify_compare_overloading(ModuleInfo, TypeCtor,
     ;
         MaybeCanonical0 = noncanon(NonCanonical0),
         (
-            NonCanonical0 = noncanon_abstract(_IsSolverType),
+            ( NonCanonical0 = noncanon_abstract(_IsSolverType)
+            ; NonCanonical0 = noncanon_subtype
+            ),
             MaybeCanonical = MaybeCanonical0
         ;
             NonCanonical0 = noncanon_uni_cmp(Uni0, Cmp0),
diff --git a/compiler/parse_tree_out.m b/compiler/parse_tree_out.m
index 5f3b8b153..316343e9d 100644
--- a/compiler/parse_tree_out.m
+++ b/compiler/parse_tree_out.m
@@ -1430,7 +1430,9 @@ mercury_output_where_attributes(Info, TypeVarSet, MaybeSolverTypeDetails,
         MaybeCanonical, MaybeDirectArgs, Stream, !IO) :-
     ( if
         MaybeSolverTypeDetails = no,
-        MaybeCanonical = canon,
+        ( MaybeCanonical = canon
+        ; MaybeCanonical = noncanon(noncanon_subtype)
+        ),
         MaybeDirectArgs = no
     then
         true
@@ -1441,6 +1443,10 @@ mercury_output_where_attributes(Info, TypeVarSet, MaybeSolverTypeDetails,
             MaybeUniPred = no,
             MaybeCmpPred = no,
             io.write_string(Stream, "type_is_abstract_noncanonical", !IO)
+        ;
+            MaybeCanonical = noncanon(noncanon_subtype),
+            MaybeUniPred = no,
+            MaybeCmpPred = no
         ;
             (
                 MaybeCanonical = canon,
diff --git a/compiler/post_term_analysis.m b/compiler/post_term_analysis.m
index 7aae59f50..77f806299 100644
--- a/compiler/post_term_analysis.m
+++ b/compiler/post_term_analysis.m
@@ -239,6 +239,9 @@ special_pred_needs_term_check(ModuleInfo, SpecialPredId, TypeDefn)
         ;
             NonCanonical = noncanon_abstract(_),
             unexpected($pred, "type is local yet it is noncanon_abstract")
+        ;
+            NonCanonical = noncanon_subtype,
+            NeedsTermCheck = no
         )
     ;
         MaybeCanonical = canon,
diff --git a/compiler/prog_data.m b/compiler/prog_data.m
index 7ef7001c3..7ac83717a 100644
--- a/compiler/prog_data.m
+++ b/compiler/prog_data.m
@@ -722,13 +722,18 @@ cons_id_is_const_struct(ConsId, ConstNum) :-
 :- func arg_pos_width_to_width_only(arg_pos_width) = arg_width.
 
     % The noncanon functor gives the user-defined unification and/or comparison
-    % predicates for a noncanonical type, if they are known. The value
-    % noncanon_abstract represents a type whose definition uses the syntax
-    % `where type_is_abstract_noncanonical' and has been read from an
+    % predicates for a noncanonical type, if they are known.
+    %
+    % The value noncanon_abstract represents a type whose definition uses the
+    % syntax `where type_is_abstract_noncanonical' and has been read from an
     % .int2 file. This means we know that the type has a noncanonical
     % representation, but we don't know what the unification or comparison
     % predicates are.
     %
+    % The value noncanon_subtype represents a subtype whose base type has
+    % user-defined unification and/or comparison predicates. Subtypes cannot
+    % have their own user-defined unification/comparison predicates.
+    %
 :- type maybe_canonical
     --->    canon
     ;       noncanon(noncanonical).
@@ -737,7 +742,8 @@ cons_id_is_const_struct(ConsId, ConstNum) :-
     --->    noncanon_uni_cmp(equality_pred, comparison_pred)
     ;       noncanon_uni_only(equality_pred)
     ;       noncanon_cmp_only(comparison_pred)
-    ;       noncanon_abstract(is_solver_type).
+    ;       noncanon_abstract(is_solver_type)
+    ;       noncanon_subtype.
 
     % The `where' attributes of a solver type definition must begin
     % with
diff --git a/compiler/special_pred.m b/compiler/special_pred.m
index e80c67535..f455022ad 100644
--- a/compiler/special_pred.m
+++ b/compiler/special_pred.m
@@ -339,7 +339,9 @@ special_pred_for_type_needs_typecheck(ModuleInfo, SpecialPredId, TypeBody) :-
                 ; NonCanonical = noncanon_cmp_only(_)
                 )
             ;
-                NonCanonical = noncanon_abstract(_),
+                ( NonCanonical = noncanon_abstract(_)
+                ; NonCanonical = noncanon_subtype
+                ),
                 fail
             )
         ;
@@ -354,6 +356,7 @@ special_pred_for_type_needs_typecheck(ModuleInfo, SpecialPredId, TypeBody) :-
             ;
                 ( NonCanonical = noncanon_uni_only(_)
                 ; NonCanonical = noncanon_abstract(_)
+                ; NonCanonical = noncanon_subtype
                 ),
                 fail
             )
diff --git a/compiler/unify_proc.m b/compiler/unify_proc.m
index bbcb09da2..1f12c0857 100644
--- a/compiler/unify_proc.m
+++ b/compiler/unify_proc.m
@@ -180,10 +180,20 @@ generate_unify_proc_body(SpecDefnInfo, X, Y, Clauses, !Info) :-
     info_get_module_info(!.Info, ModuleInfo),
     TypeBody = SpecDefnInfo ^ spdi_type_body,
     Context = SpecDefnInfo ^ spdi_context,
-    % We used to special-case the type_ctors for which
-    % is_type_ctor_a_builtin_dummy(TypeCtor) = is_builtin_dummy_type_ctor,
-    % but both those types now have user-defined unify and compare preds.
     ( if
+        TypeBody = hlds_du_type(_, yes(SuperType), _, _, _)
+    then
+        % Unify subtype terms after casting to base type.
+        % This is necessary in high-level data grades,
+        % and saves some code in low-level data grades.
+        TVarSet = SpecDefnInfo ^ spdi_tvarset,
+        get_du_base_type(ModuleInfo, TVarSet, SuperType, BaseType),
+        generate_unify_proc_body_eqv(Context, BaseType, X, Y, Clause, !Info),
+        Clauses = [Clause]
+    else if
+        % We used to special-case the type_ctors for which
+        % is_type_ctor_a_builtin_dummy(TypeCtor) = is_builtin_dummy_type_ctor,
+        % but both those types now have user-defined unify and compare preds.
         type_body_has_user_defined_equality_pred(ModuleInfo,
             TypeBody, UserEqComp)
     then
@@ -234,59 +244,47 @@ generate_unify_proc_body(SpecDefnInfo, X, Y, Clauses, !Info) :-
             Clauses = [Clause]
         ;
             TypeBody = hlds_du_type(_, MaybeSuperType, _, MaybeRepn, _),
+            expect(unify(MaybeSuperType, no), $pred, "MaybeSuperType != no"),
             (
                 MaybeRepn = no,
                 unexpected($pred, "MaybeRepn = no")
             ;
                 MaybeRepn = yes(Repn)
             ),
-            ( if
-                MaybeSuperType = yes(SuperType),
-                TVarSet = SpecDefnInfo ^ spdi_tvarset,
-                get_du_base_type(ModuleInfo, TVarSet, SuperType, BaseType)
-            then
-                % Unify after casting to base type.
-                % This is necessary in high-level data grades,
-                % and saves some code in low-level data grades.
-                generate_unify_proc_body_eqv(Context, BaseType, X, Y, Clause,
-                    !Info),
+            DuTypeKind = Repn ^ dur_kind,
+            (
+                ( DuTypeKind = du_type_kind_mercury_enum
+                ; DuTypeKind = du_type_kind_foreign_enum(_)
+                ),
+                generate_unify_proc_body_enum(Context, X, Y,
+                    Clause, !Info),
                 Clauses = [Clause]
-            else
-                DuTypeKind = Repn ^ dur_kind,
+            ;
+                DuTypeKind = du_type_kind_direct_dummy,
+                generate_unify_proc_body_dummy(Context, X, Y,
+                    Clause, !Info),
+                Clauses = [Clause]
+            ;
+                DuTypeKind = du_type_kind_notag(_, ArgType, _),
+                ArgIsDummy = is_type_a_dummy(ModuleInfo, ArgType),
                 (
-                    ( DuTypeKind = du_type_kind_mercury_enum
-                    ; DuTypeKind = du_type_kind_foreign_enum(_)
-                    ),
-                    generate_unify_proc_body_enum(Context, X, Y,
-                        Clause, !Info),
-                    Clauses = [Clause]
-                ;
-                    DuTypeKind = du_type_kind_direct_dummy,
+                    ArgIsDummy = is_dummy_type,
+                    % Treat this type as if it were a dummy type
+                    % itself.
                     generate_unify_proc_body_dummy(Context, X, Y,
                         Clause, !Info),
                     Clauses = [Clause]
                 ;
-                    DuTypeKind = du_type_kind_notag(_, ArgType, _),
-                    ArgIsDummy = is_type_a_dummy(ModuleInfo, ArgType),
-                    (
-                        ArgIsDummy = is_dummy_type,
-                        % Treat this type as if it were a dummy type
-                        % itself.
-                        generate_unify_proc_body_dummy(Context, X, Y,
-                            Clause, !Info),
-                        Clauses = [Clause]
-                    ;
-                        ArgIsDummy = is_not_dummy_type,
-                        CtorRepns = Repn ^ dur_ctor_repns,
-                        generate_unify_proc_body_du(SpecDefnInfo,
-                            CtorRepns, X, Y, Clauses, !Info)
-                    )
-                ;
-                    DuTypeKind = du_type_kind_general,
+                    ArgIsDummy = is_not_dummy_type,
                     CtorRepns = Repn ^ dur_ctor_repns,
                     generate_unify_proc_body_du(SpecDefnInfo,
                         CtorRepns, X, Y, Clauses, !Info)
                 )
+            ;
+                DuTypeKind = du_type_kind_general,
+                CtorRepns = Repn ^ dur_ctor_repns,
+                generate_unify_proc_body_du(SpecDefnInfo,
+                    CtorRepns, X, Y, Clauses, !Info)
             )
         )
     ).
@@ -312,6 +310,9 @@ generate_unify_proc_body_user(NonCanonical, X, Y, Context, Clause, !Info) :-
         NonCanonical = noncanon_abstract(_IsSolverType),
         unexpected($pred,
             "trying to create unify proc for abstract noncanonical type")
+    ;
+        NonCanonical = noncanon_subtype,
+        unexpected($pred, "trying to create unify proc for subtype")
     ;
         ( NonCanonical = noncanon_uni_cmp(UnifyPredName, _)
         ; NonCanonical = noncanon_uni_only(UnifyPredName)
@@ -940,10 +941,18 @@ generate_compare_proc_body(SpecDefnInfo, Res, X, Y, Clause, !Info) :-
     info_get_module_info(!.Info, ModuleInfo),
     TypeBody = SpecDefnInfo ^ spdi_type_body,
     Context = SpecDefnInfo ^ spdi_context,
-    % We used to special-case the type_ctors for which
-    % is_type_ctor_a_builtin_dummy(TypeCtor) = is_builtin_dummy_type_ctor,
-    % but both those types now have user-defined unify and compare preds.
     ( if
+        TypeBody = hlds_du_type(_, yes(SuperType), _, _, _)
+    then
+        % Compare subtype terms after casting to base type.
+        TVarSet = SpecDefnInfo ^ spdi_tvarset,
+        get_du_base_type(ModuleInfo, TVarSet, SuperType, BaseType),
+        generate_compare_proc_body_eqv(Context, BaseType, Res, X, Y,
+            Clause, !Info)
+    else if
+        % We used to special-case the type_ctors for which
+        % is_type_ctor_a_builtin_dummy(TypeCtor) = is_builtin_dummy_type_ctor,
+        % but both those types now have user-defined unify and compare preds.
         type_body_has_user_defined_equality_pred(ModuleInfo, TypeBody,
             UserEqComp)
     then
@@ -989,53 +998,44 @@ generate_compare_proc_body(SpecDefnInfo, Res, X, Y, Clause, !Info) :-
                 Res, X, Y, Clause, !Info)
         ;
             TypeBody = hlds_du_type(_, MaybeSuperType, _, MaybeRepn, _),
+            expect(unify(MaybeSuperType, no), $pred, "MaybeSuperType != no"),
             (
                 MaybeRepn = no,
                 unexpected($pred, "MaybeRepn = no")
             ;
                 MaybeRepn = yes(Repn)
             ),
-            ( if
-                MaybeSuperType = yes(SuperType),
-                TVarSet = SpecDefnInfo ^ spdi_tvarset,
-                get_du_base_type(ModuleInfo, TVarSet, SuperType, BaseType)
-            then
-                % Compare after casting to base type.
-                generate_compare_proc_body_eqv(Context, BaseType, Res, X, Y,
-                    Clause, !Info)
-            else
-                DuTypeKind = Repn ^ dur_kind,
+            DuTypeKind = Repn ^ dur_kind,
+            (
+                ( DuTypeKind = du_type_kind_mercury_enum
+                ; DuTypeKind = du_type_kind_foreign_enum(_)
+                ),
+                generate_compare_proc_body_enum(Context,
+                    Res, X, Y, Clause, !Info)
+            ;
+                DuTypeKind = du_type_kind_direct_dummy,
+                generate_compare_proc_body_dummy(Context,
+                    Res, X, Y, Clause, !Info)
+            ;
+                DuTypeKind = du_type_kind_notag(_, ArgType, _),
+                ArgIsDummy = is_type_a_dummy(ModuleInfo, ArgType),
                 (
-                    ( DuTypeKind = du_type_kind_mercury_enum
-                    ; DuTypeKind = du_type_kind_foreign_enum(_)
-                    ),
-                    generate_compare_proc_body_enum(Context,
-                        Res, X, Y, Clause, !Info)
-                ;
-                    DuTypeKind = du_type_kind_direct_dummy,
+                    ArgIsDummy = is_dummy_type,
+                    % Treat this type as if it were a dummy type
+                    % itself.
                     generate_compare_proc_body_dummy(Context,
                         Res, X, Y, Clause, !Info)
                 ;
-                    DuTypeKind = du_type_kind_notag(_, ArgType, _),
-                    ArgIsDummy = is_type_a_dummy(ModuleInfo, ArgType),
-                    (
-                        ArgIsDummy = is_dummy_type,
-                        % Treat this type as if it were a dummy type
-                        % itself.
-                        generate_compare_proc_body_dummy(Context,
-                            Res, X, Y, Clause, !Info)
-                    ;
-                        ArgIsDummy = is_not_dummy_type,
-                        CtorRepns = Repn ^ dur_ctor_repns,
-                        generate_compare_proc_body_du(SpecDefnInfo,
-                            CtorRepns, Res, X, Y, Clause, !Info)
-                    )
-                ;
-                    DuTypeKind = du_type_kind_general,
+                    ArgIsDummy = is_not_dummy_type,
                     CtorRepns = Repn ^ dur_ctor_repns,
                     generate_compare_proc_body_du(SpecDefnInfo,
                         CtorRepns, Res, X, Y, Clause, !Info)
                 )
+            ;
+                DuTypeKind = du_type_kind_general,
+                CtorRepns = Repn ^ dur_ctor_repns,
+                generate_compare_proc_body_du(SpecDefnInfo,
+                    CtorRepns, Res, X, Y, Clause, !Info)
             )
         )
     ).
@@ -1062,6 +1062,9 @@ generate_compare_proc_body_user(Context, NonCanonical, Res, X, Y,
         NonCanonical = noncanon_abstract(_),
         unexpected($pred,
             "trying to create compare proc for abstract noncanonical type")
+    ;
+        NonCanonical = noncanon_subtype,
+        unexpected($pred, "trying to create compare proc for subtype")
     ;
         NonCanonical = noncanon_uni_only(_),
         % Just generate code that will call error/1.
diff --git a/doc/reference_manual.texi b/doc/reference_manual.texi
index 2615b2be3..3f8db1298 100644
--- a/doc/reference_manual.texi
+++ b/doc/reference_manual.texi
@@ -2649,6 +2649,10 @@ Example:
     --->    [T | list(T)].
 @end example
 
+Subtypes may not have user-defined equality or comparison predicates.
+The base type of a subtype may have user-defined equality or comparison,
+in which case values of the subtype will be tested for equality and compared using those predicates.
+
 There is no special interaction between subtypes and the type class system.
 
 @c -----------------------------------------------------------------------
@@ -4557,7 +4561,8 @@ the standard definition of equality is not the desired one;
 we want equality on sets to mean equality of the abstract values,
 not equality of their representations.
 To support such types, Mercury allows programmers to specify
-a user-defined equality predicate for user-defined types:
+a user-defined equality predicate for user-defined types
+(not including subtypes):
 
 @example
 :- type set(T)
diff --git a/tests/hard_coded/Mmakefile b/tests/hard_coded/Mmakefile
index 3d9947225..2eabcc120 100644
--- a/tests/hard_coded/Mmakefile
+++ b/tests/hard_coded/Mmakefile
@@ -398,6 +398,7 @@ ORDINARY_PROGS = \
 	subtype_order \
 	subtype_pack \
 	subtype_rtti \
+	subtype_user_compare \
 	sv_nested_closures \
 	sv_record_update \
 	switch_detect \
diff --git a/tests/hard_coded/subtype_user_compare.exp b/tests/hard_coded/subtype_user_compare.exp
new file mode 100644
index 000000000..acec65960
--- /dev/null
+++ b/tests/hard_coded/subtype_user_compare.exp
@@ -0,0 +1,4 @@
+compare fruit (local):  '>'
+compare citrus (local): '>'
+compare fruit (abstract): '>'
+compare citrus (abstract): '>'
diff --git a/tests/hard_coded/subtype_user_compare.m b/tests/hard_coded/subtype_user_compare.m
new file mode 100644
index 000000000..b08b86d35
--- /dev/null
+++ b/tests/hard_coded/subtype_user_compare.m
@@ -0,0 +1,70 @@
+%---------------------------------------------------------------------------%
+% vim: ts=4 sw=4 et ft=mercury
+% Test that a subtype inherits user-defined comparison from its base type.
+%---------------------------------------------------------------------------%
+:- module subtype_user_compare.
+:- interface.
+
+:- import_module io.
+
+:- pred main(io::di, io::uo) is det.
+
+%---------------------------------------------------------------------------%
+
+:- implementation.
+
+:- import_module subtype_user_compare2.
+
+:- type fruit
+    --->    apple0
+    ;       banana1
+    ;       lemon2
+    ;       orange3
+    ;       peach4
+    ;       pomelo5
+    ;       tomato6
+    where comparison is fruit_compare.
+
+:- type citrus =< fruit     % No comparison clause of its own.
+    --->    lemon2
+    ;       orange3
+    ;       pomelo5.
+
+:- pred fruit_compare(comparison_result::uo, fruit::in, fruit::in) is det.
+
+fruit_compare(Res, A, B) :-
+    IntA = fruit_int(A),
+    IntB = fruit_int(B),
+    compare(Res, IntA, IntB).
+
+:- func fruit_int(fruit) = int.     % Reverses the declaration order.
+
+fruit_int(Fruit) = Int :-
+    promise_equivalent_solutions [Int]  % Deconstructs a noncanonical type.
+    ( Fruit = apple0, Int = 0
+    ; Fruit = banana1, Int = -1
+    ; Fruit = lemon2, Int = -2
+    ; Fruit = orange3, Int = -3
+    ; Fruit = peach4, Int = -4
+    ; Fruit = pomelo5, Int = -5
+    ; Fruit = tomato6, Int = -6
+    ).
+
+main(!IO) :-
+    io.write_string("compare fruit (local):  ", !IO),
+    test_compare(lemon2 : fruit, pomelo5 : fruit, !IO),
+
+    io.write_string("compare citrus (local): ", !IO),
+    test_compare(lemon2 : citrus, pomelo5 : citrus, !IO),
+    % Like-named types from subtype_user_compare2, abstract in this module.
+    io.write_string("compare fruit (abstract): ", !IO),
+    test_compare(abs_fruit_lemon, abs_fruit_pomelo, !IO),
+
+    io.write_string("compare citrus (abstract): ", !IO),
+    test_compare(abs_citrus_lemon, abs_citrus_pomelo, !IO).
+
+:- pred test_compare(T::in, T::in, io::di, io::uo) is det.
+
+test_compare(A, B, !IO) :-
+    compare(Res, A, B),
+    io.print_line(Res, !IO).
diff --git a/tests/hard_coded/subtype_user_compare2.m b/tests/hard_coded/subtype_user_compare2.m
new file mode 100644
index 000000000..121dd1f9d
--- /dev/null
+++ b/tests/hard_coded/subtype_user_compare2.m
@@ -0,0 +1,64 @@
+%---------------------------------------------------------------------------%
+% vim: ts=4 sw=4 et ft=mercury
+% Helper for subtype_user_compare: exports its fruit types abstractly.
+%---------------------------------------------------------------------------%
+:- module subtype_user_compare2.
+:- interface.
+
+:- type abs_fruit.
+
+:- type abs_citrus.
+
+:- func abs_fruit_lemon = abs_fruit.
+:- func abs_fruit_pomelo = abs_fruit.
+
+:- func abs_citrus_lemon = abs_citrus.
+:- func abs_citrus_pomelo = abs_citrus.
+
+%---------------------------------------------------------------------------%
+
+:- implementation.
+
+:- type abs_fruit == fruit.
+:- type abs_citrus == citrus.
+
+:- type fruit
+    --->    apple0
+    ;       banana1
+    ;       lemon2
+    ;       orange3
+    ;       peach4
+    ;       pomelo5
+    ;       tomato6
+    where comparison is fruit_compare.
+
+:- type citrus =< fruit     % No comparison clause of its own.
+    --->    lemon2
+    ;       orange3
+    ;       pomelo5.
+
+:- pred fruit_compare(comparison_result::uo, fruit::in, fruit::in) is det.
+
+fruit_compare(Res, A, B) :-
+    IntA = fruit_int(A),
+    IntB = fruit_int(B),
+    compare(Res, IntA, IntB).
+
+:- func fruit_int(fruit) = int.     % Reverses the declaration order.
+
+fruit_int(Fruit) = Int :-
+    promise_equivalent_solutions [Int]  % Deconstructs a noncanonical type.
+    ( Fruit = apple0, Int = 0
+    ; Fruit = banana1, Int = -1
+    ; Fruit = lemon2, Int = -2
+    ; Fruit = orange3, Int = -3
+    ; Fruit = peach4, Int = -4
+    ; Fruit = pomelo5, Int = -5
+    ; Fruit = tomato6, Int = -6
+    ).
+
+abs_fruit_lemon = lemon2.
+abs_fruit_pomelo = pomelo5.
+
+abs_citrus_lemon = lemon2.
+abs_citrus_pomelo = pomelo5.
diff --git a/tests/invalid/Mmakefile b/tests/invalid/Mmakefile
index 1ffc5be83..88348406b 100644
--- a/tests/invalid/Mmakefile
+++ b/tests/invalid/Mmakefile
@@ -313,6 +313,7 @@ SINGLEMODULE= \
 	subtype_invalid_supertype \
 	subtype_not_subset \
 	subtype_syntax \
+	subtype_user_compare \
 	switch_arm_multi_not_det \
 	tc_err1 \
 	tc_err2 \
diff --git a/tests/invalid/subtype_user_compare.err_exp b/tests/invalid/subtype_user_compare.err_exp
new file mode 100644
index 000000000..c2dc66217
--- /dev/null
+++ b/tests/invalid/subtype_user_compare.err_exp
@@ -0,0 +1,2 @@
+subtype_user_compare.m:015: Error: the subtype `citrus'/0 is not allowed to
+subtype_user_compare.m:015:   have user-defined equality or comparison.
diff --git a/tests/invalid/subtype_user_compare.m b/tests/invalid/subtype_user_compare.m
new file mode 100644
index 000000000..a0cf15457
--- /dev/null
+++ b/tests/invalid/subtype_user_compare.m
@@ -0,0 +1,28 @@
+%---------------------------------------------------------------------------%
+% vim: ts=4 sw=4 et ft=mercury
+% A subtype must not define its own equality or comparison predicates.
+%---------------------------------------------------------------------------%
+:- module subtype_user_compare.
+:- interface.
+
+:- type fruit
+    --->    apple
+    ;       orange
+    ;       lemon.
+
+:- implementation.
+
+:- type citrus =< fruit
+    --->    orange
+    ;       lemon
+    where   equality is citrus_equal,
+            comparison is citrus_compare.
+
+:- pred citrus_equal(citrus::in, citrus::in) is semidet.
+
+citrus_equal(_, _) :-      % Stub; only its presence matters.
+    semidet_true.
+
+:- pred citrus_compare(comparison_result::uo, citrus::in, citrus::in) is det.
+
+citrus_compare(=, _, _).    % Stub; only its presence matters.
-- 
2.30.0



More information about the reviews mailing list