src/HOLCF/Tools/Domain/domain_isomorphism.ML
changeset 35494 45c9a8278faf
parent 35490 63f8121c6585
child 35498 5c70de748522
--- a/src/HOLCF/Tools/Domain/domain_isomorphism.ML	Mon Mar 01 23:54:50 2010 -0800
+++ b/src/HOLCF/Tools/Domain/domain_isomorphism.ML	Tue Mar 02 00:34:26 2010 -0800
@@ -122,6 +122,17 @@
 fun mk_deflation t =
   Const (@{const_name deflation}, Term.fastype_of t --> boolT) $ t;
 
+fun mk_lub t =
+  let
+    val T = Term.range_type (Term.fastype_of t);
+    val lub_const = Const (@{const_name lub}, (T --> boolT) --> T);
+    val UNIV_const = @{term "UNIV :: nat set"};
+    val image_type = (natT --> T) --> (natT --> boolT) --> T --> boolT;
+    val image_const = Const (@{const_name image}, image_type);
+  in
+    lub_const $ (image_const $ t $ UNIV_const)
+  end;
+
 (* splits a cterm into the right and lefthand sides of equality *)
 fun dest_eqs t = HOLogic.dest_eq (HOLogic.dest_Trueprop t);
 
@@ -424,7 +435,8 @@
         fun mk_goal take_const = mk_deflation (take_const $ i);
         val goal = mk_trp (foldr1 mk_conj (map mk_goal take_consts));
         val adm_rules =
-          @{thms adm_conj adm_deflation cont2cont_fst cont2cont_snd cont_id};
+          @{thms adm_conj adm_subst [OF _ adm_deflation]
+                 cont2cont_fst cont2cont_snd cont_id};
         val bottom_rules =
           take_0_thms @ @{thms deflation_UU simp_thms};
         val deflation_rules =
@@ -436,8 +448,10 @@
          EVERY
           [rtac @{thm nat.induct} 1,
            simp_tac (HOL_basic_ss addsimps bottom_rules) 1,
-           simp_tac (HOL_basic_ss addsimps take_Suc_thms) 1,
-           REPEAT (resolve_tac deflation_rules 1 ORELSE atac 1)])
+           asm_simp_tac (HOL_basic_ss addsimps take_Suc_thms) 1,
+           REPEAT (etac @{thm conjE} 1
+                   ORELSE resolve_tac deflation_rules 1
+                   ORELSE atac 1)])
       end;
     fun conjuncts [] thm = []
       | conjuncts (n::[]) thm = [(n, thm)]
@@ -795,7 +809,8 @@
         val start_thms =
           @{thm split_def} :: map_apply_thms;
         val adm_rules =
-          @{thms adm_conj adm_deflation cont2cont_fst cont2cont_snd cont_id};
+          @{thms adm_conj adm_subst [OF _ adm_deflation]
+                 cont2cont_fst cont2cont_snd cont_id};
         val bottom_rules =
           @{thms fst_strict snd_strict deflation_UU simp_thms};
         val deflation_rules =
@@ -821,115 +836,58 @@
         (conjuncts deflation_map_binds deflation_map_thm);
     val thy = DeflMapData.map (fold Thm.add_thm deflation_map_thms) thy;
 
-    (* define copy combinators *)
-    val new_dts =
-      map (apsnd (map (fst o dest_TFree)) o dest_Type o fst) dom_eqns;
-    val copy_arg_type = mk_tupleT (map (fn (T, _) => T ->> T) dom_eqns);
-    val copy_arg = Free ("f", copy_arg_type);
-    val copy_args =
-      let fun mk_copy_args [] t = []
-            | mk_copy_args (_::[]) t = [t]
-            | mk_copy_args (_::xs) t =
-                mk_fst t :: mk_copy_args xs (mk_snd t);
-      in mk_copy_args doms copy_arg end;
-    fun copy_of_dtyp (T, dt) =
-        if Datatype_Aux.is_rec_type dt
-        then copy_of_dtyp' (T, dt)
-        else mk_ID T
-    and copy_of_dtyp' (T, Datatype_Aux.DtRec i) = nth copy_args i
-      | copy_of_dtyp' (T, Datatype_Aux.DtTFree a) = mk_ID T
-      | copy_of_dtyp' (T, Datatype_Aux.DtType (c, ds)) =
-        case Symtab.lookup map_tab' c of
-          SOME f =>
-          list_ccomb
-            (Const (f, mapT T), map copy_of_dtyp (snd (dest_Type T) ~~ ds))
-        | NONE =>
-          (warning ("copy_of_dtyp: unknown type constructor " ^ c); mk_ID T);
-    fun define_copy ((tbind, (rep_const, abs_const)), (lhsT, rhsT)) thy =
+    (* definitions and proofs related to take functions *)
+    val (take_info, thy) =
+      define_take_functions (dom_binds ~~ iso_infos) thy;
+    val {take_consts, take_defs, chain_take_thms, take_0_thms,
+         take_Suc_thms, deflation_take_thms} = take_info;
+
+    (* least-upper-bound lemma for take functions *)
+    val lub_take_lemma =
       let
-        val copy_type = copy_arg_type ->> (lhsT ->> lhsT);
-        val copy_bind = Binding.suffix_name "_copy" tbind;
-        val (copy_const, thy) = thy |>
-          Sign.declare_const ((copy_bind, copy_type), NoSyn);
-        val dtyp = Datatype_Aux.dtyp_of_typ new_dts rhsT;
-        val body = copy_of_dtyp (rhsT, dtyp);
-        val comp = mk_cfcomp (abs_const, mk_cfcomp (body, rep_const));
-        val rhs = big_lambda copy_arg comp;
-        val eqn = Logic.mk_equals (copy_const, rhs);
-        val (copy_def, thy) =
-          thy
-          |> Sign.add_path (Binding.name_of tbind)
-          |> yield_singleton (PureThy.add_defs false o map Thm.no_attributes)
-              (Binding.name "copy_def", eqn)
-          ||> Sign.parent_path;
-      in ((copy_const, copy_def), thy) end;
-    val ((copy_consts, copy_defs), thy) = thy
-      |> fold_map define_copy (dom_binds ~~ rep_abs_consts ~~ dom_eqns)
-      |>> ListPair.unzip;
-
-    (* define combined copy combinator *)
-    val ((c_const, c_def_thms), thy) =
-      if length doms = 1
-      then ((hd copy_consts, []), thy)
-      else
-        let
-          val c_type = copy_arg_type ->> copy_arg_type;
-          val c_name = space_implode "_" (map Binding.name_of dom_binds);
-          val c_bind = Binding.name (c_name ^ "_copy");
-          val c_body =
-              mk_tuple (map (mk_capply o rpair copy_arg) copy_consts);
-          val c_rhs = big_lambda copy_arg c_body;
-          val (c_const, thy) =
-            Sign.declare_const ((c_bind, c_type), NoSyn) thy;
-          val c_eqn = Logic.mk_equals (c_const, c_rhs);
-          val (c_def_thms, thy) =
-            thy
-            |> Sign.add_path c_name
-            |> (PureThy.add_defs false o map Thm.no_attributes)
-                [(Binding.name "copy_def", c_eqn)]
-            ||> Sign.parent_path;
-        in ((c_const, c_def_thms), thy) end;
-
-    (* fixed-point lemma for combined copy combinator *)
-    val fix_copy_lemma =
-      let
-        fun mk_map_ID (map_const, (T, rhsT)) =
-          list_ccomb (map_const, map mk_ID (snd (dest_Type T)));
+        val lhs = mk_tuple (map mk_lub take_consts);
+        fun mk_map_ID (map_const, (lhsT, rhsT)) =
+          list_ccomb (map_const, map mk_ID (snd (dest_Type lhsT)));
         val rhs = mk_tuple (map mk_map_ID (map_consts ~~ dom_eqns));
-        val goal = mk_eqs (mk_fix c_const, rhs);
-        val rules =
-          [@{thm pair_collapse}, @{thm split_def}]
-          @ map_apply_thms
-          @ c_def_thms @ copy_defs
-          @ MapIdData.get thy;
-        val tac = simp_tac (beta_ss addsimps rules) 1;
+        val goal = mk_trp (mk_eq (lhs, rhs));
+        val start_rules =
+            @{thms thelub_Pair [symmetric] ch2ch_Pair} @ chain_take_thms
+            @ @{thms pair_collapse split_def}
+            @ map_apply_thms @ MapIdData.get thy;
+        val rules0 =
+            @{thms iterate_0 Pair_strict} @ take_0_thms;
+        val rules1 =
+            @{thms iterate_Suc Pair_fst_snd_eq fst_conv snd_conv}
+            @ take_Suc_thms;
+        val tac =
+            EVERY
+            [simp_tac (HOL_basic_ss addsimps start_rules) 1,
+             simp_tac (HOL_basic_ss addsimps @{thms fix_def2}) 1,
+             rtac @{thm lub_eq} 1,
+             rtac @{thm nat.induct} 1,
+             simp_tac (HOL_basic_ss addsimps rules0) 1,
+             asm_full_simp_tac (beta_ss addsimps rules1) 1];
       in
         Goal.prove_global thy [] [] goal (K tac)
       end;
 
-    (* prove reach lemmas *)
-    val reach_thm_projs =
-      let fun mk_projs []      t = []
-            | mk_projs (x::[]) t = [(x, t)]
-            | mk_projs (x::xs) t = (x, mk_fst t) :: mk_projs xs (mk_snd t);
-      in mk_projs dom_binds (mk_fix c_const) end;
-    fun prove_reach_thm (((bind, t), map_ID_thm), (lhsT, rhsT)) thy =
+    (* prove lub of take equals ID *)
+    fun prove_lub_take (((bind, take_const), map_ID_thm), (lhsT, rhsT)) thy =
       let
-        val x = Free ("x", lhsT);
-        val goal = mk_eqs (mk_capply (t, x), x);
-        val rules =
-          fix_copy_lemma :: map_ID_thm :: @{thms fst_conv snd_conv ID1};
-        val tac = simp_tac (HOL_basic_ss addsimps rules) 1;
-        val reach_thm = Goal.prove_global thy [] [] goal (K tac);
+        val i = Free ("i", natT);
+        val goal = mk_eqs (mk_lub (lambda i (take_const $ i)), mk_ID lhsT);
+        val tac =
+            EVERY
+            [rtac @{thm trans} 1, rtac map_ID_thm 2,
+             cut_facts_tac [lub_take_lemma] 1,
+             REPEAT (etac @{thm Pair_inject} 1), atac 1];
+        val lub_take_thm = Goal.prove_global thy [] [] goal (K tac);
       in
-        thy
-        |> Sign.add_path (Binding.name_of bind)
-        |> yield_singleton (PureThy.add_thms o map Thm.no_attributes)
-            (Binding.name "reach", reach_thm)
-        ||> Sign.parent_path
+        add_qualified_thm "lub_take" (Binding.name_of bind, lub_take_thm) thy
       end;
-    val (reach_thms, thy) = thy |>
-      fold_map prove_reach_thm (reach_thm_projs ~~ map_ID_thms ~~ dom_eqns);
+    val (lub_take_thms, thy) =
+        fold_map prove_lub_take
+          (dom_binds ~~ take_consts ~~ map_ID_thms ~~ dom_eqns) thy;
 
   in
     (iso_infos, thy)