separate map-related code into new function define_map_functions
authorhuffman
Sun, 14 Mar 2010 14:10:05 -0700
changeset 35791 dc175fe29326
parent 35790 a9507cd84326
child 35792 48cd2261817b
separate map-related code into new function define_map_functions
src/HOLCF/Tools/Domain/domain_isomorphism.ML
--- a/src/HOLCF/Tools/Domain/domain_isomorphism.ML	Sun Mar 14 15:50:17 2010 +0100
+++ b/src/HOLCF/Tools/Domain/domain_isomorphism.ML	Sun Mar 14 14:10:05 2010 -0700
@@ -13,6 +13,17 @@
       (Domain_Take_Proofs.iso_info list
        * Domain_Take_Proofs.take_induct_info) * theory
 
+  val define_map_functions :
+      (binding * Domain_Take_Proofs.iso_info) list ->
+      theory ->
+      {
+        map_consts : term list,
+        map_apply_thms : thm list,
+        map_unfold_thms : thm list,
+        deflation_map_thms : thm list
+      }
+      * theory
+
   val domain_isomorphism_cmd :
     (string list * binding * mixfix * string * (binding * binding) option) list
       -> theory -> theory
@@ -241,6 +252,129 @@
       ((Binding.qualified true name dbind, thm), []);
 
 (******************************************************************************)
+(*************************** defining map functions ***************************)
+(******************************************************************************)
+
+fun define_map_functions
+    (spec : (binding * Domain_Take_Proofs.iso_info) list)
+    (thy : theory) =
+  let
+
+    (* retrieve components of spec *)
+    val dbinds = map fst spec;
+    val iso_infos = map snd spec;
+    val dom_eqns = map (fn x => (#absT x, #repT x)) iso_infos;
+    val rep_abs_consts = map (fn x => (#rep_const x, #abs_const x)) iso_infos;
+
+    (* declare map functions *)
+    fun declare_map_const (tbind, (lhsT, rhsT)) thy =
+      let
+        val map_type = mapT lhsT;
+        val map_bind = Binding.suffix_name "_map" tbind;
+      in
+        Sign.declare_const ((map_bind, map_type), NoSyn) thy
+      end;
+    val (map_consts, thy) = thy |>
+      fold_map declare_map_const (dbinds ~~ dom_eqns);
+
+    (* defining equations for map functions *)
+    local
+      fun unprime a = Library.unprefix "'" a;
+      fun mapvar T = Free (unprime (fst (dest_TFree T)), T ->> T);
+      fun map_lhs (map_const, lhsT) =
+          (lhsT, list_ccomb (map_const, map mapvar (snd (dest_Type lhsT))));
+      val tab1 = map map_lhs (map_consts ~~ map fst dom_eqns);
+      val Ts = (snd o dest_Type o fst o hd) dom_eqns;
+      val tab = (Ts ~~ map mapvar Ts) @ tab1;
+      fun mk_map_spec (((rep_const, abs_const), map_const), (lhsT, rhsT)) =
+        let
+          val lhs = Domain_Take_Proofs.map_of_typ thy tab lhsT;
+          val body = Domain_Take_Proofs.map_of_typ thy tab rhsT;
+          val rhs = mk_cfcomp (abs_const, mk_cfcomp (body, rep_const));
+        in mk_eqs (lhs, rhs) end;
+    in
+      val map_specs =
+          map mk_map_spec (rep_abs_consts ~~ map_consts ~~ dom_eqns);
+    end;
+
+    (* register recursive definition of map functions *)
+    val map_binds = map (Binding.suffix_name "_map") dbinds;
+    val ((map_apply_thms, map_unfold_thms), thy) =
+      add_fixdefs (map_binds ~~ map_specs) thy;
+
+    (* prove deflation theorems for map functions *)
+    val deflation_abs_rep_thms = map deflation_abs_rep iso_infos;
+    val deflation_map_thm =
+      let
+        fun unprime a = Library.unprefix "'" a;
+        fun mk_f T = Free (unprime (fst (dest_TFree T)), T ->> T);
+        fun mk_assm T = mk_trp (mk_deflation (mk_f T));
+        fun mk_goal (map_const, (lhsT, rhsT)) =
+          let
+            val (_, Ts) = dest_Type lhsT;
+            val map_term = list_ccomb (map_const, map mk_f Ts);
+          in mk_deflation map_term end;
+        val assms = (map mk_assm o snd o dest_Type o fst o hd) dom_eqns;
+        val goals = map mk_goal (map_consts ~~ dom_eqns);
+        val goal = mk_trp (foldr1 HOLogic.mk_conj goals);
+        val start_thms =
+          @{thm split_def} :: map_apply_thms;
+        val adm_rules =
+          @{thms adm_conj adm_subst [OF _ adm_deflation]
+                 cont2cont_fst cont2cont_snd cont_id};
+        val bottom_rules =
+          @{thms fst_strict snd_strict deflation_UU simp_thms};
+        val deflation_rules =
+          @{thms conjI deflation_ID}
+          @ deflation_abs_rep_thms
+          @ Domain_Take_Proofs.get_deflation_thms thy;
+      in
+        Goal.prove_global thy [] assms goal (fn {prems, ...} =>
+         EVERY
+          [simp_tac (HOL_basic_ss addsimps start_thms) 1,
+           rtac @{thm fix_ind} 1,
+           REPEAT (resolve_tac adm_rules 1),
+           simp_tac (HOL_basic_ss addsimps bottom_rules) 1,
+           simp_tac beta_ss 1,
+           simp_tac (HOL_basic_ss addsimps @{thms fst_conv snd_conv}) 1,
+           REPEAT (etac @{thm conjE} 1),
+           REPEAT (resolve_tac (deflation_rules @ prems) 1 ORELSE atac 1)])
+      end;
+    fun conjuncts [] thm = []
+      | conjuncts (n::[]) thm = [(n, thm)]
+      | conjuncts (n::ns) thm = let
+          val thmL = thm RS @{thm conjunct1};
+          val thmR = thm RS @{thm conjunct2};
+        in (n, thmL):: conjuncts ns thmR end;
+    val deflation_map_binds = dbinds |>
+        map (Binding.prefix_name "deflation_" o Binding.suffix_name "_map");
+    val (deflation_map_thms, thy) = thy |>
+      (PureThy.add_thms o map (Thm.no_attributes o apsnd Drule.export_without_context))
+        (conjuncts deflation_map_binds deflation_map_thm);
+
+    (* register map functions in theory data *)
+    local
+      fun register_map ((dname, map_name), defl_thm) =
+          Domain_Take_Proofs.add_map_function (dname, map_name, defl_thm);
+      val dnames = map (fst o dest_Type o fst) dom_eqns;
+      val map_names = map (fst o dest_Const) map_consts;
+    in
+      val thy =
+          fold register_map (dnames ~~ map_names ~~ deflation_map_thms) thy;
+    end;
+
+    val result =
+      {
+        map_consts = map_consts,
+        map_apply_thms = map_apply_thms,
+        map_unfold_thms = map_unfold_thms,
+        deflation_map_thms = deflation_map_thms
+      }
+  in
+    (result, thy)
+  end;
+
+(******************************************************************************)
 (******************************* main function ********************************)
 (******************************************************************************)
 
@@ -417,41 +551,11 @@
         map mk_info (dom_eqns ~~ rep_abs_consts ~~ iso_thms)
       end
 
-    (* declare map functions *)
-    fun declare_map_const (tbind, (lhsT, rhsT)) thy =
-      let
-        val map_type = mapT lhsT;
-        val map_bind = Binding.suffix_name "_map" tbind;
-      in
-        Sign.declare_const ((map_bind, map_type), NoSyn) thy
-      end;
-    val (map_consts, thy) = thy |>
-      fold_map declare_map_const (dbinds ~~ dom_eqns);
-
-    (* defining equations for map functions *)
-    local
-      fun unprime a = Library.unprefix "'" a;
-      fun mapvar T = Free (unprime (fst (dest_TFree T)), T ->> T);
-      fun map_lhs (map_const, lhsT) =
-          (lhsT, list_ccomb (map_const, map mapvar (snd (dest_Type lhsT))));
-      val tab1 = map map_lhs (map_consts ~~ map fst dom_eqns);
-      val Ts = (snd o dest_Type o fst o hd) dom_eqns;
-      val tab = (Ts ~~ map mapvar Ts) @ tab1;
-      fun mk_map_spec (((rep_const, abs_const), map_const), (lhsT, rhsT)) =
-        let
-          val lhs = Domain_Take_Proofs.map_of_typ thy tab lhsT;
-          val body = Domain_Take_Proofs.map_of_typ thy tab rhsT;
-          val rhs = mk_cfcomp (abs_const, mk_cfcomp (body, rep_const));
-        in mk_eqs (lhs, rhs) end;
-    in
-      val map_specs =
-          map mk_map_spec (rep_abs_consts ~~ map_consts ~~ dom_eqns);
-    end;
-
-    (* register recursive definition of map functions *)
-    val map_binds = map (Binding.suffix_name "_map") dbinds;
-    val ((map_apply_thms, map_unfold_thms), thy) =
-      add_fixdefs (map_binds ~~ map_specs) thy;
+    (* definitions and proofs related to map functions *)
+    val (map_info, thy) =
+        define_map_functions (dbinds ~~ iso_infos) thy;
+    val { map_consts, map_apply_thms, map_unfold_thms,
+          deflation_map_thms } = map_info;
 
     (* prove isodefl rules for map functions *)
     val isodefl_thm =
@@ -531,61 +635,6 @@
         (map_ID_binds ~~ map_ID_thms);
     val thy = MapIdData.map (fold Thm.add_thm map_ID_thms) thy;
 
-    (* prove deflation theorems for map functions *)
-    val deflation_abs_rep_thms = map deflation_abs_rep iso_infos;
-    val deflation_map_thm =
-      let
-        fun unprime a = Library.unprefix "'" a;
-        fun mk_f T = Free (unprime (fst (dest_TFree T)), T ->> T);
-        fun mk_assm T = mk_trp (mk_deflation (mk_f T));
-        fun mk_goal (map_const, (lhsT, rhsT)) =
-          let
-            val (_, Ts) = dest_Type lhsT;
-            val map_term = list_ccomb (map_const, map mk_f Ts);
-          in mk_deflation map_term end;
-        val assms = (map mk_assm o snd o dest_Type o fst o hd) dom_eqns;
-        val goals = map mk_goal (map_consts ~~ dom_eqns);
-        val goal = mk_trp (foldr1 HOLogic.mk_conj goals);
-        val start_thms =
-          @{thm split_def} :: map_apply_thms;
-        val adm_rules =
-          @{thms adm_conj adm_subst [OF _ adm_deflation]
-                 cont2cont_fst cont2cont_snd cont_id};
-        val bottom_rules =
-          @{thms fst_strict snd_strict deflation_UU simp_thms};
-        val deflation_rules =
-          @{thms conjI deflation_ID}
-          @ deflation_abs_rep_thms
-          @ Domain_Take_Proofs.get_deflation_thms thy;
-      in
-        Goal.prove_global thy [] assms goal (fn {prems, ...} =>
-         EVERY
-          [simp_tac (HOL_basic_ss addsimps start_thms) 1,
-           rtac @{thm fix_ind} 1,
-           REPEAT (resolve_tac adm_rules 1),
-           simp_tac (HOL_basic_ss addsimps bottom_rules) 1,
-           simp_tac beta_ss 1,
-           simp_tac (HOL_basic_ss addsimps @{thms fst_conv snd_conv}) 1,
-           REPEAT (etac @{thm conjE} 1),
-           REPEAT (resolve_tac (deflation_rules @ prems) 1 ORELSE atac 1)])
-      end;
-    val deflation_map_binds = dbinds |>
-        map (Binding.prefix_name "deflation_" o Binding.suffix_name "_map");
-    val (deflation_map_thms, thy) = thy |>
-      (PureThy.add_thms o map (Thm.no_attributes o apsnd Drule.export_without_context))
-        (conjuncts deflation_map_binds deflation_map_thm);
-
-    (* register map functions in theory data *)
-    local
-      fun register_map ((dname, map_name), defl_thm) =
-          Domain_Take_Proofs.add_map_function (dname, map_name, defl_thm);
-      val dnames = map (fst o dest_Type o fst) dom_eqns;
-      val map_names = map (fst o dest_Const) map_consts;
-    in
-      val thy =
-          fold register_map (dnames ~~ map_names ~~ deflation_map_thms) thy;
-    end;
-
     (* definitions and proofs related to take functions *)
     val (take_info, thy) =
         Domain_Take_Proofs.define_take_functions