src/HOLCF/Tools/Domain/domain_extender.ML
changeset 33798 46cbbcbd4e68
parent 33796 6442aa3773a2
child 33955 fff6f11b1f09
--- a/src/HOLCF/Tools/Domain/domain_extender.ML	Thu Nov 19 15:31:19 2009 -0800
+++ b/src/HOLCF/Tools/Domain/domain_extender.ML	Thu Nov 19 15:41:52 2009 -0800
@@ -17,6 +17,18 @@
       ((string * string option) list * binding * mixfix *
        (binding * (bool * binding option * typ) list * mixfix) list) list
       -> theory -> theory
+
+  val add_new_domain_cmd:
+      string ->
+      ((string * string option) list * binding * mixfix *
+       (binding * (bool * binding option * string) list * mixfix) list) list
+      -> theory -> theory
+
+  val add_new_domain:
+      string ->
+      ((string * string option) list * binding * mixfix *
+       (binding * (bool * binding option * typ) list * mixfix) list) list
+      -> theory -> theory
 end;
 
 structure Domain_Extender :> DOMAIN_EXTENDER =
@@ -26,13 +38,14 @@
 
 (* ----- general testing and preprocessing of constructor list -------------- *)
 fun check_and_sort_domain
+    (definitional : bool)
     (dtnvs : (string * typ list) list)
     (cons'' : (binding * (bool * binding option * typ) list * mixfix) list list)
-    (sg : theory)
+    (thy : theory)
     : ((string * typ list) *
        (binding * (bool * binding option * typ) list * mixfix) list) list =
   let
-    val defaultS = Sign.defaultS sg;
+    val defaultS = Sign.defaultS thy;
 
     val test_dupl_typs =
       case duplicates (op =) (map fst dtnvs) of 
@@ -78,27 +91,27 @@
           | analyse indirect (t as Type(s,typl)) =
             (case AList.lookup (op =) dtnvs s of
                NONE =>
-                 if s mem indirect_ok
+                 if definitional orelse s mem indirect_ok
                  then Type(s,map (analyse false) typl)
                  else Type(s,map (analyse true) typl)
              | SOME typevars =>
                  if indirect 
                  then error ("Indirect recursion of type " ^ 
-                             quote (string_of_typ sg t))
+                             quote (string_of_typ thy t))
                  else if dname <> s orelse
                          (** BUG OR FEATURE?:
                              mutual recursion may use different arguments **)
                          remove_sorts typevars = remove_sorts typl 
                  then Type(s,map (analyse true) typl)
                  else error ("Direct recursion of type " ^ 
-                             quote (string_of_typ sg t) ^ 
+                             quote (string_of_typ thy t) ^ 
                              " with different arguments"))
           | analyse indirect (TVar _) = Imposs "extender:analyse";
         fun check_pcpo lazy T =
             let val ok = if lazy then cpo_type else pcpo_type
-            in if ok sg T then T
+            in if ok thy T then T
                else error ("Constructor argument type is not of sort pcpo: " ^
-                           string_of_typ sg T)
+                           string_of_typ thy T)
             end;
         fun analyse_arg (lazy, sel, T) =
             (lazy, sel, check_pcpo lazy (analyse false T));
@@ -126,7 +139,8 @@
     fun thy_type  (dname,tvars,mx) = (dname, length tvars, mx);
     fun thy_arity (dname,tvars,mx) =
         (Sign.full_name thy''' dname, map (snd o dest_TFree) tvars, pcpoS);
-    val thy'' = thy'''
+    val thy'' =
+      thy'''
       |> Sign.add_types (map thy_type dtnvs)
       |> fold (AxClass.axiomatize_arity o thy_arity) dtnvs;
     val cons'' =
@@ -135,8 +149,8 @@
       map (fn (dname,vs,mx) => (Sign.full_name thy''' dname,vs)) dtnvs;
     val eqs' : ((string * typ list) *
         (binding * (bool * binding option * typ) list * mixfix) list) list =
-      check_and_sort_domain dtnvs' cons'' thy'';
-    val thy' = thy'' |> Domain_Syntax.add_syntax comp_dnam eqs';
+      check_and_sort_domain false dtnvs' cons'' thy'';
+    val thy' = thy'' |> Domain_Syntax.add_syntax false comp_dnam eqs';
     val dts  = map (Type o fst) eqs';
     val new_dts = map (fn ((s,Ts),_) => (s, map (fst o dest_TFree) Ts)) eqs';
     fun strip ss = Library.drop (find_index (fn s => s = "'") ss + 1, ss);
@@ -154,7 +168,82 @@
         ) : cons;
     val eqs : eq list =
         map (fn (dtnvs,cons') => (dtnvs, map one_con cons')) eqs';
-    val thy = thy' |> Domain_Axioms.add_axioms comp_dnam eqs;
+    val thy = thy' |> Domain_Axioms.add_axioms false comp_dnam eqs;
+    val ((rewss, take_rews), theorems_thy) =
+        thy
+          |> fold_map (fn eq => Domain_Theorems.theorems (eq, eqs)) eqs
+          ||>> Domain_Theorems.comp_theorems (comp_dnam, eqs);
+  in
+    theorems_thy
+      |> Sign.add_path (Long_Name.base_name comp_dnam)
+      |> PureThy.add_thmss
+           [((Binding.name "rews", flat rewss @ take_rews), [])]
+      |> snd
+      |> Sign.parent_path
+  end;
+
+fun gen_add_new_domain
+    (prep_typ : theory -> 'a -> typ)
+    (comp_dnam : string)
+    (eqs''' : ((string * string option) list * binding * mixfix *
+               (binding * (bool * binding option * 'a) list * mixfix) list) list)
+    (thy''' : theory) =
+  let
+    fun readS (SOME s) = Syntax.read_sort_global thy''' s
+      | readS NONE = Sign.defaultS thy''';
+    fun readTFree (a, s) = TFree (a, readS s);
+
+    val dtnvs = map (fn (vs,dname:binding,mx,_) => 
+                        (dname, map readTFree vs, mx)) eqs''';
+    val cons''' = map (fn (_,_,_,cons) => cons) eqs''';
+    fun thy_type  (dname,tvars,mx) = (dname, length tvars, mx);
+    fun thy_arity (dname,tvars,mx) =
+      (Sign.full_name thy''' dname, map (snd o dest_TFree) tvars, @{sort rep});
+
+    (* this theory is used just for parsing and error checking *)
+    val tmp_thy = thy'''
+      |> Theory.copy
+      |> Sign.add_types (map thy_type dtnvs)
+      |> fold (AxClass.axiomatize_arity o thy_arity) dtnvs;
+
+    val cons'' : (binding * (bool * binding option * typ) list * mixfix) list list =
+      map (map (upd_second (map (upd_third (prep_typ tmp_thy))))) cons''';
+    val dtnvs' : (string * typ list) list =
+      map (fn (dname,vs,mx) => (Sign.full_name thy''' dname,vs)) dtnvs;
+    val eqs' : ((string * typ list) *
+        (binding * (bool * binding option * typ) list * mixfix) list) list =
+      check_and_sort_domain true dtnvs' cons'' tmp_thy;
+
+    fun mk_arg_typ (lazy, dest_opt, T) = if lazy then mk_uT T else T;
+    fun mk_con_typ (bind, args, mx) =
+        if null args then oneT else foldr1 mk_sprodT (map mk_arg_typ args);
+    fun mk_eq_typ (_, cons) = foldr1 mk_ssumT (map mk_con_typ cons);
+    
+    val thy'' = thy''' |>
+      Domain_Isomorphism.domain_isomorphism
+        (map (fn ((vs, dname, mx, _), eq) =>
+                 (map fst vs, dname, mx, mk_eq_typ eq))
+             (eqs''' ~~ eqs'))
+
+    val thy' = thy'' |> Domain_Syntax.add_syntax true comp_dnam eqs';
+    val dts  = map (Type o fst) eqs';
+    val new_dts = map (fn ((s,Ts),_) => (s, map (fst o dest_TFree) Ts)) eqs';
+    fun strip ss = Library.drop (find_index (fn s => s = "'") ss + 1, ss);
+    fun typid (Type  (id,_)) =
+        let val c = hd (Symbol.explode (Long_Name.base_name id))
+        in if Symbol.is_letter c then c else "t" end
+      | typid (TFree (id,_)   ) = hd (strip (tl (Symbol.explode id)))
+      | typid (TVar ((id,_),_)) = hd (tl (Symbol.explode id));
+    fun one_con (con,args,mx) =
+        ((Syntax.const_name mx (Binding.name_of con)),
+         ListPair.map (fn ((lazy,sel,tp),vn) =>
+           mk_arg ((lazy, DatatypeAux.dtyp_of_typ new_dts tp),
+                   Option.map Binding.name_of sel,vn))
+                      (args,(mk_var_names(map (typid o third) args)))
+        ) : cons;
+    val eqs : eq list =
+        map (fn (dtnvs,cons') => (dtnvs, map one_con cons')) eqs';
+    val thy = thy' |> Domain_Axioms.add_axioms true comp_dnam eqs;
     val ((rewss, take_rews), theorems_thy) =
         thy
           |> fold_map (fn eq => Domain_Theorems.theorems (eq, eqs)) eqs
@@ -171,6 +260,9 @@
 val add_domain = gen_add_domain Sign.certify_typ;
 val add_domain_cmd = gen_add_domain Syntax.read_typ_global;
 
+val add_new_domain = gen_add_new_domain Sign.certify_typ;
+val add_new_domain_cmd = gen_add_new_domain Syntax.read_typ_global;
+
 
 (** outer syntax **)
 
@@ -205,6 +297,7 @@
     P.and_list1 domain_decl;
 
 fun mk_domain
+    (definitional : bool)
     (opt_name : string option,
      doms : ((((string * string option) list * binding) * mixfix) *
              ((binding * (bool * binding option * string) list) * mixfix) list) list ) =
@@ -216,11 +309,19 @@
                 (vs, t, mx, map (fn ((c, ds), mx) => (c, ds, mx)) cons)) doms;
     val comp_dnam =
         case opt_name of NONE => space_implode "_" names | SOME s => s;
-  in add_domain_cmd comp_dnam specs end;
+  in
+    if definitional 
+    then add_new_domain_cmd comp_dnam specs
+    else add_domain_cmd comp_dnam specs
+  end;
 
 val _ =
   OuterSyntax.command "domain" "define recursive domains (HOLCF)"
-    K.thy_decl (domains_decl >> (Toplevel.theory o mk_domain));
+    K.thy_decl (domains_decl >> (Toplevel.theory o mk_domain false));
+
+val _ =
+  OuterSyntax.command "new_domain" "define recursive domains (HOLCF)"
+    K.thy_decl (domains_decl >> (Toplevel.theory o mk_domain true));
 
 end;