make domain package work with non-cpo argument types
authorhuffman
Wed, 27 Oct 2010 11:10:36 -0700
changeset 40218 f7d4d023a899
parent 40217 656bb85f01ab
child 40219 b283680d8044
make domain package work with non-cpo argument types
src/HOLCF/Library/Strict_Fun.thy
src/HOLCF/Powerdomains.thy
src/HOLCF/Representable.thy
src/HOLCF/Tools/Domain/domain_isomorphism.ML
src/HOLCF/Tools/Domain/domain_take_proofs.ML
--- a/src/HOLCF/Library/Strict_Fun.thy	Wed Oct 27 11:06:53 2010 -0700
+++ b/src/HOLCF/Library/Strict_Fun.thy	Wed Oct 27 11:10:36 2010 -0700
@@ -213,7 +213,7 @@
 
 setup {*
   Domain_Isomorphism.add_type_constructor
-    (@{type_name "sfun"}, @{term sfun_defl}, @{const_name sfun_map})
+    (@{type_name "sfun"}, @{const_name sfun_defl}, @{const_name sfun_map}, [true, true])
 *}
 
 end
--- a/src/HOLCF/Powerdomains.thy	Wed Oct 27 11:06:53 2010 -0700
+++ b/src/HOLCF/Powerdomains.thy	Wed Oct 27 11:10:36 2010 -0700
@@ -43,9 +43,9 @@
 
 setup {*
   fold Domain_Isomorphism.add_type_constructor
-    [(@{type_name "upper_pd"}, @{term upper_defl}, @{const_name upper_map}),
-     (@{type_name "lower_pd"}, @{term lower_defl}, @{const_name lower_map}),
-     (@{type_name "convex_pd"}, @{term convex_defl}, @{const_name convex_map})]
+    [(@{type_name "upper_pd"}, @{const_name upper_defl}, @{const_name upper_map}, [true]),
+     (@{type_name "lower_pd"}, @{const_name lower_defl}, @{const_name lower_map}, [true]),
+     (@{type_name "convex_pd"}, @{const_name convex_defl}, @{const_name convex_map}, [true])]
 *}
 
 end
--- a/src/HOLCF/Representable.thy	Wed Oct 27 11:06:53 2010 -0700
+++ b/src/HOLCF/Representable.thy	Wed Oct 27 11:10:36 2010 -0700
@@ -284,11 +284,11 @@
 
 setup {*
   fold Domain_Isomorphism.add_type_constructor
-    [(@{type_name cfun}, @{term cfun_defl}, @{const_name cfun_map}),
-     (@{type_name ssum}, @{term ssum_defl}, @{const_name ssum_map}),
-     (@{type_name sprod}, @{term sprod_defl}, @{const_name sprod_map}),
-     (@{type_name prod}, @{term cprod_defl}, @{const_name cprod_map}),
-     (@{type_name "u"}, @{term u_defl}, @{const_name u_map})]
+    [(@{type_name cfun}, @{const_name cfun_defl}, @{const_name cfun_map}, [true, true]),
+     (@{type_name ssum}, @{const_name ssum_defl}, @{const_name ssum_map}, [true, true]),
+     (@{type_name sprod}, @{const_name sprod_defl}, @{const_name sprod_map}, [true, true]),
+     (@{type_name prod}, @{const_name prod_defl}, @{const_name cprod_map}, [true, true]),
+     (@{type_name "u"}, @{const_name u_defl}, @{const_name u_map}, [true])]
 *}
 
 end
--- a/src/HOLCF/Tools/Domain/domain_isomorphism.ML	Wed Oct 27 11:06:53 2010 -0700
+++ b/src/HOLCF/Tools/Domain/domain_isomorphism.ML	Wed Oct 27 11:10:36 2010 -0700
@@ -27,8 +27,9 @@
   val domain_isomorphism_cmd :
     (string list * binding * mixfix * string * (binding * binding) option) list
       -> theory -> theory
+
   val add_type_constructor :
-    (string * term * string) -> theory -> theory
+    (string * string * string * bool list) -> theory -> theory
 
   val setup : theory -> theory
 end;
@@ -44,14 +45,19 @@
 
 val beta_tac = simp_tac beta_ss;
 
+fun is_cpo thy T = Sign.of_sort thy (T, @{sort cpo});
+fun is_bifinite thy T = Sign.of_sort thy (T, @{sort bifinite});
+
 (******************************************************************************)
 (******************************** theory data *********************************)
 (******************************************************************************)
 
 structure DeflData = Theory_Data
 (
-  (* terms like "foo_defl" *)
-  type T = term Symtab.table;
+  (* constant names like "foo_defl" *)
+  (* list indicates which type arguments correspond to deflation parameters *)
+  (* alternatively, which type arguments allow indirect recursion *)
+  type T = (string * bool list) Symtab.table;
   val empty = Symtab.empty;
   val extend = I;
   fun merge data = Symtab.merge (K true) data;
@@ -76,9 +82,9 @@
 );
 
 fun add_type_constructor
-  (tname, defl_const, map_name) =
-    DeflData.map (Symtab.insert (K true) (tname, defl_const))
-    #> Domain_Take_Proofs.add_map_function (tname, map_name)
+  (tname, defl_name, map_name, flags) =
+    DeflData.map (Symtab.insert (K true) (tname, (defl_name, flags)))
+    #> Domain_Take_Proofs.add_map_function (tname, map_name, flags)
 
 val setup =
     RepData.setup #> MapIdData.setup #> IsodeflData.setup
@@ -91,15 +97,11 @@
 open HOLCF_Library;
 
 infixr 6 ->>;
-infix -->>;
+infixr -->>;
 
 val udomT = @{typ udom};
 val deflT = @{typ "defl"};
 
-fun mapT (T as Type (_, Ts)) =
-    (map (fn T => T ->> T) Ts) -->> (T ->> T)
-  | mapT T = T ->> T;
-
 fun mk_DEFL T =
   Const (@{const_name defl}, Term.itselfT T --> deflT) $ Logic.mk_type T;
 
@@ -155,6 +157,8 @@
         case lhs of
           Const (@{const_name Rep_CFun}, _) $ f $ (x as Free _) =>
             mk_eqn (f, big_lambda x rhs)
+        | f $ Const (@{const_name TYPE}, T) =>
+            mk_eqn (f, Abs ("t", T, rhs))
         | Const _ => Logic.mk_equals (lhs, rhs)
         | _ => raise TERM ("lhs not of correct form", [lhs, rhs]);
     val eqns = map mk_eqn projs;
@@ -204,17 +208,30 @@
 (****************** deflation combinators and map functions *******************)
 (******************************************************************************)
 
+fun mk_defl_type (flags : bool list) (Ts : typ list) =
+    map (Term.itselfT o snd) (filter_out fst (flags ~~ Ts)) --->
+    map (K deflT) (filter I flags) -->> deflT;
+
 fun defl_of_typ
-    (tab : term Symtab.table)
+    (thy : theory)
+    (tab : (string * bool list) Symtab.table)
     (T : typ) : term =
   let
     fun is_closed_typ (Type (_, Ts)) = forall is_closed_typ Ts
+      | is_closed_typ (TFree (n, s)) = not (Sign.subsort thy (s, @{sort bifinite}))
       | is_closed_typ _ = false;
     fun defl_of (TFree (a, _)) = Free (Library.unprefix "'" a, deflT)
       | defl_of (TVar _) = error ("defl_of_typ: TVar")
       | defl_of (T as Type (c, Ts)) =
         case Symtab.lookup tab c of
-          SOME t => list_ccomb (t, map defl_of Ts)
+          SOME (s, flags) =>
+            let
+              val defl_const = Const (s, mk_defl_type flags Ts);
+              val type_args = map (Logic.mk_type o snd) (filter_out fst (flags ~~ Ts));
+              val defl_args = map (defl_of o snd) (filter fst (flags ~~ Ts));
+            in
+              list_ccomb (list_comb (defl_const, type_args), defl_args)
+            end
         | NONE => if is_closed_typ T
                   then mk_DEFL T
                   else error ("defl_of_typ: type variable under unsupported type constructor " ^ c);
@@ -258,6 +275,10 @@
     val dom_eqns = map (fn x => (#absT x, #repT x)) iso_infos;
     val rep_abs_consts = map (fn x => (#rep_const x, #abs_const x)) iso_infos;
 
+    fun mapT (T as Type (_, Ts)) =
+        (map (fn T => T ->> T) (filter (is_cpo thy) Ts)) -->> (T ->> T)
+      | mapT T = T ->> T;
+
     (* declare map functions *)
     fun declare_map_const (tbind, (lhsT, rhsT)) thy =
       let
@@ -274,7 +295,7 @@
       fun unprime a = Library.unprefix "'" a;
       fun mapvar T = Free (unprime (fst (dest_TFree T)), T ->> T);
       fun map_lhs (map_const, lhsT) =
-          (lhsT, list_ccomb (map_const, map mapvar (snd (dest_Type lhsT))));
+          (lhsT, list_ccomb (map_const, map mapvar (filter (is_cpo thy) (snd (dest_Type lhsT)))));
       val tab1 = map map_lhs (map_consts ~~ map fst dom_eqns);
       val Ts = (snd o dest_Type o fst o hd) dom_eqns;
       val tab = (Ts ~~ map mapvar Ts) @ tab1;
@@ -304,9 +325,9 @@
         fun mk_goal (map_const, (lhsT, rhsT)) =
           let
             val (_, Ts) = dest_Type lhsT;
-            val map_term = list_ccomb (map_const, map mk_f Ts);
+            val map_term = list_ccomb (map_const, map mk_f (filter (is_cpo thy) Ts));
           in mk_deflation map_term end;
-        val assms = (map mk_assm o snd o dest_Type o fst o hd) dom_eqns;
+        val assms = (map mk_assm o filter (is_cpo thy) o snd o dest_Type o fst o hd) dom_eqns;
         val goals = map mk_goal (map_consts ~~ dom_eqns);
         val goal = mk_trp (foldr1 HOLogic.mk_conj goals);
         val start_thms =
@@ -346,11 +367,15 @@
 
     (* register map functions in theory data *)
     local
+      fun register_map ((dname, map_name), args) =
+          Domain_Take_Proofs.add_map_function (dname, map_name, args);
       val dnames = map (fst o dest_Type o fst) dom_eqns;
       val map_names = map (fst o dest_Const) map_consts;
+      fun args (T, _) = case T of Type (_, Ts) => map (is_cpo thy) Ts | _ => [];
+      val argss = map args dom_eqns;
     in
       val thy =
-          fold Domain_Take_Proofs.add_map_function (dnames ~~ map_names) thy;
+          fold register_map (dnames ~~ map_names ~~ argss) thy;
     end;
 
     (* register deflation theorems *)
@@ -433,25 +458,35 @@
     val dbinds = map (fn (_, dbind, _, _, _) => dbind) doms;
     val morphs = map (fn (_, _, _, _, morphs) => morphs) doms;
 
+    (* determine deflation combinator arguments *)
+    fun defl_flags (vs, tbind, mx, rhs, morphs) =
+      let fun argT v = TFree (v, the (AList.lookup (op =) sorts v));
+      in map (is_bifinite thy o argT) vs end;
+    val defl_flagss = map defl_flags doms;
+
     (* declare deflation combinator constants *)
     fun declare_defl_const (vs, tbind, mx, rhs, morphs) thy =
       let
-        val defl_type = map (K deflT) vs -->> deflT;
+        fun argT v = TFree (v, the (AList.lookup (op =) sorts v));
+        val Ts = map argT vs;
+        val flags = map (is_bifinite thy) Ts;
+        val defl_type = mk_defl_type flags Ts;
         val defl_bind = Binding.suffix_name "_defl" tbind;
       in
         Sign.declare_const ((defl_bind, defl_type), NoSyn) thy
       end;
     val (defl_consts, thy) = fold_map declare_defl_const doms thy;
+    val defl_names = map (fst o dest_Const) defl_consts;
 
     (* defining equations for type combinators *)
+    val dnames = map (fst o dest_Type o fst) dom_eqns;
     val defl_tab1 = DeflData.get thy;
-    val defl_tab2 =
-      Symtab.make (map (fst o dest_Type o fst) dom_eqns ~~ defl_consts);
+    val defl_tab2 = Symtab.make (dnames ~~ (defl_names ~~ defl_flagss));
     val defl_tab' = Symtab.merge (K true) (defl_tab1, defl_tab2);
     val thy = DeflData.put defl_tab' thy;
     fun mk_defl_spec (lhsT, rhsT) =
-      mk_eqs (defl_of_typ defl_tab' lhsT,
-              defl_of_typ defl_tab' rhsT);
+      mk_eqs (defl_of_typ thy defl_tab' lhsT,
+              defl_of_typ thy defl_tab' rhsT);
     val defl_specs = map mk_defl_spec dom_eqns;
 
     (* register recursive definition of deflation combinators *)
@@ -462,9 +497,11 @@
     (* define types using deflation combinators *)
     fun make_repdef ((vs, tbind, mx, _, _), defl_const) thy =
       let
-        fun tfree a = TFree (a, the (AList.lookup (op =) sorts a))
-        val reps = map (mk_DEFL o tfree) vs;
-        val defl = list_ccomb (defl_const, reps);
+        fun tfree a = TFree (a, the (AList.lookup (op =) sorts a));
+        val Ts = map tfree vs;
+        val type_args = map Logic.mk_type (filter_out (is_bifinite thy) Ts);
+        val defl_args = map mk_DEFL (filter (is_bifinite thy) Ts);
+        val defl = list_ccomb (list_comb (defl_const, type_args), defl_args);
         val ((_, _, _, {DEFL, ...}), thy) =
           Repdef.add_repdef false NONE (tbind, map (rpair dummyS) vs, mx) defl NONE thy;
       in
@@ -560,10 +597,12 @@
         fun mk_goal ((map_const, defl_const), (T, rhsT)) =
           let
             val (_, Ts) = dest_Type T;
-            val map_term = list_ccomb (map_const, map mk_f Ts);
-            val defl_term = list_ccomb (defl_const, map mk_d Ts);
+            val map_term = list_ccomb (map_const, map mk_f (filter (is_cpo thy) Ts));
+            val type_args = map Logic.mk_type (filter_out (is_bifinite thy) Ts);
+            val defl_args = map mk_d (filter (is_bifinite thy) Ts);
+            val defl_term = list_ccomb (list_comb (defl_const, type_args), defl_args);
           in isodefl_const T $ map_term $ defl_term end;
-        val assms = (map mk_assm o snd o dest_Type o fst o hd) dom_eqns;
+        val assms = (map mk_assm o filter (is_cpo thy) o snd o dest_Type o fst o hd) dom_eqns;
         val goals = map mk_goal (map_consts ~~ defl_consts ~~ dom_eqns);
         val goal = mk_trp (foldr1 HOLogic.mk_conj goals);
         val start_thms =
@@ -610,7 +649,8 @@
         (((map_const, (lhsT, _)), DEFL_thm), isodefl_thm) =
       let
         val Ts = snd (dest_Type lhsT);
-        val lhs = list_ccomb (map_const, map mk_ID Ts);
+        fun is_cpo T = Sign.of_sort thy (T, @{sort cpo});
+        val lhs = list_ccomb (map_const, map mk_ID (filter is_cpo Ts));
         val goal = mk_eqs (lhs, mk_ID lhsT);
         val tac = EVERY
           [rtac @{thm isodefl_DEFL_imp_ID} 1,
@@ -640,8 +680,9 @@
     val lub_take_lemma =
       let
         val lhs = mk_tuple (map mk_lub take_consts);
+        fun is_cpo T = Sign.of_sort thy (T, @{sort cpo});
         fun mk_map_ID (map_const, (lhsT, rhsT)) =
-          list_ccomb (map_const, map mk_ID (snd (dest_Type lhsT)));
+          list_ccomb (map_const, map mk_ID (filter is_cpo (snd (dest_Type lhsT))));
         val rhs = mk_tuple (map mk_map_ID (map_consts ~~ dom_eqns));
         val goal = mk_trp (mk_eq (lhs, rhs));
         val map_ID_thms = MapIdData.get (ProofContext.init_global thy);
--- a/src/HOLCF/Tools/Domain/domain_take_proofs.ML	Wed Oct 27 11:06:53 2010 -0700
+++ b/src/HOLCF/Tools/Domain/domain_take_proofs.ML	Wed Oct 27 11:10:36 2010 -0700
@@ -55,8 +55,8 @@
   val map_of_typ :
     theory -> (typ * term) list -> typ -> term
 
-  val add_map_function : (string * string) -> theory -> theory
-  val get_map_tab : theory -> string Symtab.table
+  val add_map_function : (string * string * bool list) -> theory -> theory
+  val get_map_tab : theory -> (string * bool list) Symtab.table
   val add_deflation_thm : thm -> theory -> theory
   val get_deflation_thms : theory -> thm list
   val setup : theory -> theory
@@ -120,7 +120,9 @@
 structure MapData = Theory_Data
 (
   (* constant names like "foo_map" *)
-  type T = string Symtab.table;
+  (* list indicates which type arguments correspond to map arguments *)
+  (* alternatively, which type arguments allow indirect recursion *)
+  type T = (string * bool list) Symtab.table;
   val empty = Symtab.empty;
   val extend = I;
   fun merge data = Symtab.merge (K true) data;
@@ -132,8 +134,8 @@
   val description = "theorems like deflation a ==> deflation (foo_map$a)"
 );
 
-fun add_map_function (tname, map_name) =
-    MapData.map (Symtab.insert (K true) (tname, map_name));
+fun add_map_function (tname, map_name, bs) =
+    MapData.map (Symtab.insert (K true) (tname, (map_name, bs)));
 
 fun add_deflation_thm thm =
     Context.theory_map (DeflMapData.add_thm thm);
@@ -188,10 +190,11 @@
           SOME m => (m, true) | NONE => map_of' T
     and map_of' (T as (Type (c, Ts))) =
         (case Symtab.lookup map_tab c of
-          SOME map_name =>
+          SOME (map_name, ds) =>
           let
-            val map_type = map auto Ts -->> auto T;
-            val (ms, bs) = map_split map_of Ts;
+            val Ts' = map snd (filter fst (ds ~~ Ts));
+            val map_type = map auto Ts' -->> auto T;
+            val (ms, bs) = map_split map_of Ts';
           in
             if exists I bs
             then (list_ccomb (Const (map_name, map_type), ms), true)
@@ -235,9 +238,6 @@
     val dom_eqns = map (fn x => (#absT x, #repT x)) iso_infos;
     val rep_abs_consts = map (fn x => (#rep_const x, #abs_const x)) iso_infos;
 
-    (* get table of map functions *)
-    val map_tab = MapData.get thy;
-
     fun mk_projs []      t = []
       | mk_projs (x::[]) t = [(x, t)]
       | mk_projs (x::xs) t = (x, mk_fst t) :: mk_projs xs (mk_snd t);