factored out handling of sum types again
authorkrauss
Thu, 06 Dec 2007 12:23:52 +0100
changeset 25555 224a40e39457
parent 25554 082d97057e23
child 25556 8d3b7c27049b
factored out handling of sum types again
src/HOL/Tools/function_package/mutual.ML
src/HOL/Tools/function_package/sum_tree.ML
--- a/src/HOL/Tools/function_package/mutual.ML	Thu Dec 06 00:21:34 2007 +0100
+++ b/src/HOL/Tools/function_package/mutual.ML	Thu Dec 06 12:23:52 2007 +0100
@@ -27,40 +27,6 @@
 open FundefLib
 open FundefCommon
 
-(* Theory dependencies *)
-val sum_case_rules = thms "Sum_Type.sum_cases"
-val split_apply = thm "Product_Type.split"
-val projl_inl = thm "Sum_Type.Projl_Inl"
-val projr_inr = thm "Sum_Type.Projr_Inr"
-
-(* top-down access in balanced tree *)
-fun access_top_down {left, right, init} len i =
-    BalancedTree.access {left = (fn f => f o left), right = (fn f => f o right), init = I} len i init
-
-(* Sum types *)
-fun mk_sumT LT RT = Type ("+", [LT, RT])
-fun mk_sumcase TL TR T l r = Const (@{const_name "Sum_Type.sum_case"}, (TL --> T) --> (TR --> T) --> mk_sumT TL TR --> T) $ l $ r
-
-val App = curry op $
-
-fun mk_inj ST n i = 
-    access_top_down 
-    { init = (ST, I : term -> term),
-      left = (fn (T as Type ("+", [LT, RT]), inj) => (LT, inj o App (Const (@{const_name "Inl"}, LT --> T)))),
-      right =(fn (T as Type ("+", [LT, RT]), inj) => (RT, inj o App (Const (@{const_name "Inr"}, RT --> T))))} n i 
-    |> snd
-
-fun mk_proj ST n i = 
-    access_top_down 
-    { init = (ST, I : term -> term),
-      left = (fn (T as Type ("+", [LT, RT]), proj) => (LT, App (Const (@{const_name "Projl"}, T --> LT)) o proj)),
-      right =(fn (T as Type ("+", [LT, RT]), proj) => (RT, App (Const (@{const_name "Projr"}, T --> RT)) o proj))} n i
-    |> snd
-
-fun mk_sumcases T fs =
-      BalancedTree.make (fn ((f, fT), (g, gT)) => (mk_sumcase fT gT T f g, mk_sumT fT gT)) 
-                        (map (fn f => (f, domain_type (fastype_of f))) fs)
-      |> fst
 
 
 
@@ -138,8 +104,8 @@
         val dresultTs = distinct (Type.eq_type Vartab.empty) resultTs
         val n' = length dresultTs
 
-        val RST = BalancedTree.make (uncurry mk_sumT) dresultTs
-        val ST = BalancedTree.make (uncurry mk_sumT) argTs
+        val RST = BalancedTree.make (uncurry SumTree.mk_sumT) dresultTs
+        val ST = BalancedTree.make (uncurry SumTree.mk_sumT) argTs
 
         val fsum_type = ST --> RST
 
@@ -151,7 +117,7 @@
                 val vars = map_index (fn (j,T) => Free ("x" ^ string_of_int j, T)) caTs (* FIXME: Bind xs properly *)
                 val i' = find_index (fn Ta => Type.eq_type Vartab.empty (Ta, resultT)) dresultTs + 1 
 
-                val f_exp = mk_proj RST n' i' (Free fsum_var $ mk_inj ST num i (foldr1 HOLogic.mk_prod vars))
+                val f_exp = SumTree.mk_proj RST n' i' (Free fsum_var $ SumTree.mk_inj ST num i (foldr1 HOLogic.mk_prod vars))
                 val def = Term.abstract_over (Free fsum_var, fold_rev lambda vars f_exp)
 
                 val rew = (n, fold_rev lambda vars f_exp)
@@ -165,8 +131,8 @@
             let
               val MutualPart {i, i', ...} = get_part f parts
             in
-              (qs, gs, mk_inj ST num i (foldr1 (mk_prod_abs qs) args),
-               mk_inj RST n' i' (replace_frees rews rhs)
+              (qs, gs, SumTree.mk_inj ST num i (foldr1 (mk_prod_abs qs) args),
+               SumTree.mk_inj RST n' i' (replace_frees rews rhs)
                                |> Envir.beta_norm)
             end
 
@@ -243,7 +209,7 @@
                  (HOLogic.Trueprop $ HOLogic.mk_eq (list_comb (f, args), rhs))
                  (fn _ => (LocalDefs.unfold_tac ctxt all_orig_fdefs)
                           THEN EqSubst.eqsubst_tac ctxt [0] [simp] 1
-                          THEN SIMPSET' (fn ss => simp_tac (ss addsimps [projl_inl, projr_inr])) 1)
+                          THEN SIMPSET' (fn ss => simp_tac (ss addsimps SumTree.proj_in_rules)) 1)
         |> restore_cond 
         |> export
     end
@@ -279,21 +245,21 @@
           end
           
       val Ps = map2 mk_P parts newPs
-      val case_exp = mk_sumcases HOLogic.boolT Ps
+      val case_exp = SumTree.mk_sumcases HOLogic.boolT Ps
                      
       val induct_inst =
           forall_elim (cert case_exp) induct
-                      |> full_simplify (HOL_basic_ss addsimps (split_apply :: sum_case_rules))
+                      |> full_simplify SumTree.sumcase_split_ss
                       |> full_simplify (HOL_basic_ss addsimps all_f_defs)
           
       fun project rule (MutualPart {cargTs, i, ...}) k =
           let
             val afs = map_index (fn (j,T) => Free ("a" ^ string_of_int (j + k), T)) cargTs (* FIXME! *)
-            val inj = mk_inj ST n i (foldr1 HOLogic.mk_prod afs)
+            val inj = SumTree.mk_inj ST n i (foldr1 HOLogic.mk_prod afs)
           in
             (rule
               |> forall_elim (cert inj)
-              |> full_simplify (HOL_basic_ss addsimps (split_apply :: sum_case_rules))
+              |> full_simplify SumTree.sumcase_split_ss
               |> fold_rev (forall_intr o cert) (afs @ newPs),
              k + length cargTs)
           end
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/HOL/Tools/function_package/sum_tree.ML	Thu Dec 06 12:23:52 2007 +0100
@@ -0,0 +1,44 @@
+(*  Title:      HOL/Tools/function_package/sum_tree.ML
+    ID:         $Id$
+    Author:     Alexander Krauss, TU Muenchen
+
+Some common tools for working with sum types in balanced tree form.
+*)
+
+structure SumTree =
+struct
+
+(* Theory dependencies *)
+val proj_in_rules = [thm "Sum_Type.Projl_Inl", thm "Sum_Type.Projr_Inr"]
+val sumcase_split_ss = HOL_basic_ss addsimps (@{thm "Product_Type.split"} :: @{thms "Sum_Type.sum_cases"})
+
+(* top-down access in balanced tree *)
+fun access_top_down {left, right, init} len i =
+    BalancedTree.access {left = (fn f => f o left), right = (fn f => f o right), init = I} len i init
+
+(* Sum types *)
+fun mk_sumT LT RT = Type ("+", [LT, RT])
+fun mk_sumcase TL TR T l r = Const (@{const_name "Sum_Type.sum_case"}, (TL --> T) --> (TR --> T) --> mk_sumT TL TR --> T) $ l $ r
+
+val App = curry op $
+
+fun mk_inj ST n i = 
+    access_top_down 
+    { init = (ST, I : term -> term),
+      left = (fn (T as Type ("+", [LT, RT]), inj) => (LT, inj o App (Const (@{const_name "Inl"}, LT --> T)))),
+      right =(fn (T as Type ("+", [LT, RT]), inj) => (RT, inj o App (Const (@{const_name "Inr"}, RT --> T))))} n i 
+    |> snd
+
+fun mk_proj ST n i = 
+    access_top_down 
+    { init = (ST, I : term -> term),
+      left = (fn (T as Type ("+", [LT, RT]), proj) => (LT, App (Const (@{const_name "Projl"}, T --> LT)) o proj)),
+      right =(fn (T as Type ("+", [LT, RT]), proj) => (RT, App (Const (@{const_name "Projr"}, T --> RT)) o proj))} n i
+    |> snd
+
+fun mk_sumcases T fs =
+      BalancedTree.make (fn ((f, fT), (g, gT)) => (mk_sumcase fT gT T f g, mk_sumT fT gT)) 
+                        (map (fn f => (f, domain_type (fastype_of f))) fs)
+      |> fst
+
+end