combine check_and_sort_domain with main function; rewrite much of the error-checking code
authorhuffman
Wed, 20 Oct 2010 16:19:25 -0700
changeset 40044 89381a2f8864
parent 40043 3007608368e2
child 40045 e0f372e18f3e
combine check_and_sort_domain with main function; rewrite much of the error-checking code
src/HOLCF/Tools/Domain/domain.ML
--- a/src/HOLCF/Tools/Domain/domain.ML	Wed Oct 20 13:22:30 2010 -0700
+++ b/src/HOLCF/Tools/Domain/domain.ML	Wed Oct 20 16:19:25 2010 -0700
@@ -41,94 +41,13 @@
 fun second (_,x,_) = x;
 fun third  (_,_,x) = x;
 
-fun upd_first  f (x,y,z) = (f x,   y,   z);
-fun upd_second f (x,y,z) = (  x, f y,   z);
-fun upd_third  f (x,y,z) = (  x,   y, f z);
-
-(* ----- general testing and preprocessing of constructor list -------------- *)
-fun check_and_sort_domain
-    (arg_sort : bool -> sort)
-    (dtnvs : (string * typ list) list)
-    (cons'' : (binding * (bool * binding option * typ) list * mixfix) list list)
-    (thy : theory)
-    : (binding * (bool * binding option * typ) list * mixfix) list list =
-  let
-    val defaultS = Sign.defaultS thy;
-
-    val all_cons = map (Binding.name_of o first) (flat cons'');
-    val test_dupl_cons =
-      case duplicates (op =) all_cons of 
-        [] => false | dups => error ("Duplicate constructors: " 
-                                      ^ commas_quote dups);
-    val all_sels =
-      (map Binding.name_of o map_filter second o maps second) (flat cons'');
-    val test_dupl_sels =
-      case duplicates (op =) all_sels of
-        [] => false | dups => error("Duplicate selectors: "^commas_quote dups);
-
-    fun test_dupl_tvars s =
-      case duplicates (op =) (map(fst o dest_TFree)s) of
-        [] => false | dups => error("Duplicate type arguments: " 
-                                    ^commas_quote dups);
-    val test_dupl_tvars' = exists test_dupl_tvars (map snd dtnvs);
-
-    (* test for free type variables, illegal sort constraints on rhs,
-       non-pcpo-types and invalid use of recursive type;
-       replace sorts in type variables on rhs *)
-    fun analyse_equation ((dname,typevars),cons') = 
-      let
-        val tvars = map dest_TFree typevars;
-        fun rm_sorts (TFree(s,_)) = TFree(s,[])
-          | rm_sorts (Type(s,ts)) = Type(s,remove_sorts ts)
-          | rm_sorts (TVar(s,_))  = TVar(s,[])
-        and remove_sorts l = map rm_sorts l;
-        fun analyse indirect (TFree(v,s))  =
-            (case AList.lookup (op =) tvars v of 
-               NONE => error ("Free type variable " ^ quote v ^ " on rhs.")
-             | SOME sort => if eq_set (op =) (s, defaultS) orelse
-                               eq_set (op =) (s, sort)
-                            then TFree(v,sort)
-                            else error ("Inconsistent sort constraint" ^
-                                        " for type variable " ^ quote v))
-          | analyse indirect (t as Type(s,typl)) =
-            (case AList.lookup (op =) dtnvs s of
-               NONE => Type (s, map (analyse false) typl)
-             | SOME typevars =>
-                 if indirect 
-                 then error ("Indirect recursion of type " ^ 
-                             quote (Syntax.string_of_typ_global thy t))
-                 else if dname <> s orelse
-                         (** BUG OR FEATURE?:
-                             mutual recursion may use different arguments **)
-                         remove_sorts typevars = remove_sorts typl 
-                 then Type(s,map (analyse true) typl)
-                 else error ("Direct recursion of type " ^ 
-                             quote (Syntax.string_of_typ_global thy t) ^ 
-                             " with different arguments"))
-          | analyse indirect (TVar _) = error "extender:analyse";
-        (* a lazy argument may have an unpointed type *)
-        (* unless the argument has a selector function *)
-        fun check_pcpo sel lazy T =
-            let val sort = arg_sort (lazy andalso is_none sel) in
-              if Sign.of_sort thy (T, sort) then T
-              else error ("Constructor argument type is not of sort " ^
-                          Syntax.string_of_sort_global thy sort ^ ": " ^
-                          Syntax.string_of_typ_global thy T)
-            end;
-        fun analyse_arg (lazy, sel, T) =
-            (lazy, sel, check_pcpo sel lazy (analyse false T));
-        fun analyse_con (b, args, mx) = (b, map analyse_arg args, mx);
-      in map analyse_con cons' end; 
-  in ListPair.map analyse_equation (dtnvs,cons'')
-  end; (* let *)
-
 (* ----- calls for building new thy and thms -------------------------------- *)
 
 type info =
      Domain_Take_Proofs.iso_info list * Domain_Take_Proofs.take_induct_info;
 
 fun gen_add_domain
-    (prep_typ : theory -> 'a -> typ)
+    (prep_typ : theory -> (string * sort) list -> 'a -> typ)
     (add_isos : (binding * mixfix * (typ * typ)) list -> theory -> info * theory)
     (arg_sort : bool -> sort)
     (comp_dbind : binding)
@@ -142,8 +61,8 @@
           | readS NONE = Sign.defaultS thy;
         fun readTFree (a, s) = TFree (a, readS s);
       in
-        map (fn (vs,dname:binding,mx,_) =>
-                (dname, map readTFree vs, mx)) raw_specs
+        map (fn (vs, dbind, mx, _) =>
+                (dbind, map readTFree vs, mx)) raw_specs
       end;
 
     fun thy_type (dbind, tvars, mx) = (dbind, length tvars, mx);
@@ -158,25 +77,87 @@
 
     val dbinds : binding list =
         map (fn (_,dbind,_,_) => dbind) raw_specs;
-    val raw_conss :
+    val raw_rhss :
         (binding * (bool * binding option * 'a) list * mixfix) list list =
         map (fn (_,_,_,cons) => cons) raw_specs;
-    val conss :
-        (binding * (bool * binding option * typ) list * mixfix) list list =
-        map (map (upd_second (map (upd_third (prep_typ tmp_thy))))) raw_conss;
     val dtnvs' : (string * typ list) list =
         map (fn (dbind, vs, mx) => (Sign.full_name thy dbind, vs)) dtnvs;
-    val conss :
-        (binding * (bool * binding option * typ) list * mixfix) list list =
-        check_and_sort_domain arg_sort dtnvs' conss tmp_thy;
+
+    val all_cons = map (Binding.name_of o first) (flat raw_rhss);
+    val test_dupl_cons =
+      case duplicates (op =) all_cons of 
+        [] => false | dups => error ("Duplicate constructors: " 
+                                      ^ commas_quote dups);
+    val all_sels =
+      (map Binding.name_of o map_filter second o maps second) (flat raw_rhss);
+    val test_dupl_sels =
+      case duplicates (op =) all_sels of
+        [] => false | dups => error("Duplicate selectors: "^commas_quote dups);
+
+    fun test_dupl_tvars s =
+      case duplicates (op =) (map(fst o dest_TFree)s) of
+        [] => false | dups => error("Duplicate type arguments: " 
+                                    ^commas_quote dups);
+    val test_dupl_tvars' = exists test_dupl_tvars (map snd dtnvs');
+
+    val sorts : (string * sort) list =
+      let val all_sorts = map (map dest_TFree o snd) dtnvs';
+      in
+        case distinct (eq_set (op =)) all_sorts of
+          [sorts] => sorts
+        | _ => error "Mutually recursive domains must have same type parameters"
+      end;
+
+    (* a lazy argument may have an unpointed type *)
+    (* unless the argument has a selector function *)
+    fun check_pcpo (lazy, sel, T) =
+      let val sort = arg_sort (lazy andalso is_none sel) in
+        if Sign.of_sort tmp_thy (T, sort) then ()
+        else error ("Constructor argument type is not of sort " ^
+                    Syntax.string_of_sort_global tmp_thy sort ^ ": " ^
+                    Syntax.string_of_typ_global tmp_thy T)
+      end;
+
+    (* test for free type variables, illegal sort constraints on rhs,
+       non-pcpo-types and invalid use of recursive type;
+       replace sorts in type variables on rhs *)
+    val map_tab = Domain_Take_Proofs.get_map_tab thy;
+    fun check_rec rec_ok (T as TFree (v,_))  =
+        if AList.defined (op =) sorts v then T
+        else error ("Free type variable " ^ quote v ^ " on rhs.")
+      | check_rec rec_ok (T as Type (s, Ts)) =
+        (case AList.lookup (op =) dtnvs' s of
+          NONE =>
+            let val rec_ok' = rec_ok andalso Symtab.defined map_tab s;
+            in Type (s, map (check_rec rec_ok') Ts) end
+        | SOME typevars =>
+          if typevars <> Ts
+          then error ("Recursion of type " ^ 
+                      quote (Syntax.string_of_typ_global tmp_thy T) ^ 
+                      " with different arguments")
+          else if rec_ok then T
+          else error ("Illegal indirect recursion of type " ^ 
+                      quote (Syntax.string_of_typ_global tmp_thy T)))
+      | check_rec rec_ok (TVar _) = error "extender:check_rec";
+
+    fun prep_arg (lazy, sel, raw_T) =
+      let
+        val T = prep_typ tmp_thy sorts raw_T;
+        val _ = check_rec true T;
+        val _ = check_pcpo (lazy, sel, T);
+      in (lazy, sel, T) end;
+    fun prep_con (b, args, mx) = (b, map prep_arg args, mx);
+    fun prep_rhs cons = map prep_con cons;
+    val rhss : (binding * (bool * binding option * typ) list * mixfix) list list =
+        map prep_rhs raw_rhss;
 
     fun mk_arg_typ (lazy, dest_opt, T) = if lazy then mk_upT T else T;
     fun mk_con_typ (bind, args, mx) =
         if null args then oneT else foldr1 mk_sprodT (map mk_arg_typ args);
-    fun mk_eq_typ cons = foldr1 mk_ssumT (map mk_con_typ cons);
+    fun mk_rhs_typ cons = foldr1 mk_ssumT (map mk_con_typ cons);
 
     val absTs : typ list = map Type dtnvs';
-    val repTs : typ list = map mk_eq_typ conss;
+    val repTs : typ list = map mk_rhs_typ rhss;
 
     val iso_spec : (binding * mixfix * (typ * typ)) list =
         map (fn ((dbind, _, mx), eq) => (dbind, mx, eq))
@@ -188,7 +169,7 @@
         thy
           |> fold_map (fn ((dbind, cons), info) =>
                 Domain_Constructors.add_domain_constructors dbind cons info)
-             (dbinds ~~ conss ~~ iso_infos);
+             (dbinds ~~ rhss ~~ iso_infos);
 
     val (take_rews, thy) =
         Domain_Induction.comp_theorems comp_dbind
@@ -209,17 +190,35 @@
 fun pcpo_arg lazy = if lazy then @{sort cpo} else @{sort pcpo};
 fun rep_arg lazy = @{sort bifinite};
 
+(* Adapted from src/HOL/Tools/Datatype/datatype_data.ML *)
+fun read_typ thy sorts str =
+  let
+    val ctxt = ProofContext.init_global thy
+      |> fold (Variable.declare_typ o TFree) sorts;
+  in Syntax.read_typ ctxt str end;
+
+fun cert_typ sign sorts raw_T =
+  let
+    val T = Type.no_tvars (Sign.certify_typ sign raw_T)
+      handle TYPE (msg, _, _) => error msg;
+    val sorts' = Term.add_tfreesT T sorts;
+    val _ =
+      case duplicates (op =) (map fst sorts') of
+        [] => ()
+      | dups => error ("Inconsistent sort constraints for " ^ commas dups)
+  in T end;
+
 val add_domain =
-    gen_add_domain Sign.certify_typ Domain_Axioms.add_axioms pcpo_arg;
+    gen_add_domain cert_typ Domain_Axioms.add_axioms pcpo_arg;
 
 val add_new_domain =
-    gen_add_domain Sign.certify_typ define_isos rep_arg;
+    gen_add_domain cert_typ define_isos rep_arg;
 
 val add_domain_cmd =
-    gen_add_domain Syntax.read_typ_global Domain_Axioms.add_axioms pcpo_arg;
+    gen_add_domain read_typ Domain_Axioms.add_axioms pcpo_arg;
 
 val add_new_domain_cmd =
-    gen_add_domain Syntax.read_typ_global define_isos rep_arg;
+    gen_add_domain read_typ define_isos rep_arg;
 
 
 (** outer syntax **)