src/HOL/Tools/BNF/bnf_fp_util.ML
changeset 55966 972f0aa7091b
parent 55945 e96383acecf9
child 55968 94242fa87638
--- a/src/HOL/Tools/BNF/bnf_fp_util.ML	Thu Mar 06 22:15:01 2014 +0100
+++ b/src/HOL/Tools/BNF/bnf_fp_util.ML	Fri Mar 07 01:02:21 2014 +0100
@@ -128,7 +128,7 @@
   val split_conj_prems: int -> thm -> thm
 
   val mk_sumTN: typ list -> typ
-  val mk_sumTN_balanced: typ list -> typ
+  val mk_sumprodT_balanced: typ list list -> typ
 
   val mk_proj: typ -> int -> int -> term
 
@@ -143,7 +143,7 @@
 
   val mk_Inl: typ -> term -> term
   val mk_Inr: typ -> term -> term
-  val mk_InN: typ list -> term -> int -> term
+  val mk_tuple_balanced: term list -> term
   val mk_absumprod: typ -> term -> int -> int -> term list -> term
 
   val dest_sumT: typ -> typ * typ
@@ -155,7 +155,6 @@
   val mk_If: term -> term -> term -> term
   val mk_union: term * term -> term
 
-  val mk_sumEN: int -> thm
   val mk_absumprodE: thm -> int list -> thm
 
   val mk_sum_caseN: int -> int -> thm
@@ -331,7 +330,7 @@
 val selN = "sel"
 val sel_corecN = selN ^ "_" ^ corecN
 
-fun co_prefix fp = (if fp = Greatest_FP then "co" else "");
+fun co_prefix fp = case_fp fp "" "co";
 
 fun add_components_of_typ (Type (s, Ts)) =
     cons (Long_Name.base_name s) #> fold_rev add_components_of_typ Ts
@@ -343,16 +342,20 @@
 
 val dest_sumTN_balanced = Balanced_Tree.dest dest_sumT;
 
-(* TODO: move something like this to "HOLogic"? *)
-fun dest_tupleT 0 @{typ unit} = []
-  | dest_tupleT 1 T = [T]
-  | dest_tupleT n (Type (@{type_name prod}, [T, T'])) = T :: dest_tupleT (n - 1) T';
+fun dest_tupleT_balanced 0 @{typ unit} = []
+  | dest_tupleT_balanced n T = Balanced_Tree.dest HOLogic.dest_prodT n T;
 
-fun dest_absumprodT absT repT n ms = map2 dest_tupleT ms o dest_sumTN_balanced n o mk_repT absT repT;
+fun dest_absumprodT absT repT n ms =
+  map2 dest_tupleT_balanced ms o dest_sumTN_balanced n o mk_repT absT repT;
 
 val mk_sumTN = Library.foldr1 mk_sumT;
 val mk_sumTN_balanced = Balanced_Tree.make mk_sumT;
 
+fun mk_tupleT_balanced [] = HOLogic.unitT
+  | mk_tupleT_balanced Ts = Balanced_Tree.make HOLogic.mk_prodT Ts;
+
+val mk_sumprodT_balanced = mk_sumTN_balanced o map mk_tupleT_balanced;
+
 fun mk_proj T n k =
   let val (binders, _) = strip_typeN n T in
     fold_rev (fn T => fn t => Abs (Name.uu, T, t)) binders (Bound (n - k - 1))
@@ -371,11 +374,7 @@
 fun Inr_const LT RT = Const (@{const_name Inr}, RT --> mk_sumT (LT, RT));
 fun mk_Inr LT t = Inr_const LT (fastype_of t) $ t;
 
-fun mk_InN [_] t 1 = t
-  | mk_InN (_ :: Ts) t 1 = mk_Inl (mk_sumTN Ts) t
-  | mk_InN (LT :: Ts) t m = mk_Inr LT (mk_InN Ts t (m - 1))
-  | mk_InN Ts t _ = raise (TYPE ("mk_InN", Ts, [t]));
-
+(* FIXME: reuse "mk_inj" in function package *)
 fun mk_InN_balanced sum_T n t k =
   let
     fun repair_types T (Const (s as @{const_name Inl}, _) $ t) = repair_inj_types T s fst t
@@ -390,9 +389,12 @@
     |> repair_types sum_T
   end;
 
+fun mk_tuple_balanced [] = HOLogic.unit
+  | mk_tuple_balanced ts = Balanced_Tree.make HOLogic.mk_prod ts;
+
 fun mk_absumprod absT abs0 n k ts =
   let val abs = mk_abs absT abs0;
-  in abs $ mk_InN_balanced (domain_type (fastype_of abs)) n (HOLogic.mk_tuple ts) k end;
+  in abs $ mk_InN_balanced (domain_type (fastype_of abs)) n (mk_tuple_balanced ts) k end;
 
 fun mk_case_sum (f, g) =
   let
@@ -434,24 +436,26 @@
       if i = n then th else split n (i + 1) (conjI RSN (i, th)) handle THM _ => th;
   in split limit 1 th end;
 
-fun mk_sumEN 1 = @{thm one_pointE}
-  | mk_sumEN 2 = @{thm sumE}
-  | mk_sumEN n =
-    (fold (fn i => fn thm => @{thm obj_sumE_f} RSN (i, thm)) (2 upto n - 1) @{thm obj_sumE}) OF
-      replicate n (impI RS allI);
-
 fun mk_obj_sumEN_balanced n =
   Balanced_Tree.make (fn (thm1, thm2) => thm1 RSN (1, thm2 RSN (2, @{thm obj_sumE_f})))
     (replicate n asm_rl);
 
-fun mk_tupled_allIN 0 = @{thm unit_all_impI}
-  | mk_tupled_allIN 1 = @{thm impI[THEN allI]}
-  | mk_tupled_allIN 2 = @{thm prod_all_impI} (*optimization*)
-  | mk_tupled_allIN n = mk_tupled_allIN (n - 1) RS @{thm prod_all_impI_step};
+fun mk_tupled_allIN_balanced 0 = @{thm unit_all_impI}
+  | mk_tupled_allIN_balanced n =
+    let
+      val (tfrees, _) = BNF_Util.mk_TFrees n @{context};
+      val T = mk_tupleT_balanced tfrees;
+    in
+      @{thm asm_rl[of "ALL x. P x --> Q x" for P Q]}
+      |> Drule.instantiate' [SOME (ctyp_of @{theory} T)] []
+      |> Raw_Simplifier.rewrite_goals_rule @{context} @{thms split_paired_All[THEN eq_reflection]}
+      |> (fn thm => impI RS funpow n (fn th => allI RS th) thm)
+      |> Thm.varifyT_global
+    end;
 
 fun mk_absumprodE type_definition ms =
   let val n = length ms in
-    mk_obj_sumEN_balanced n OF map mk_tupled_allIN ms RS
+    mk_obj_sumEN_balanced n OF map mk_tupled_allIN_balanced ms RS
       (type_definition RS @{thm type_copy_obj_one_point_absE})
   end;
 
@@ -519,16 +523,16 @@
     map_cong0s =
   let
     val n = length sym_map_comps;
-    val rewrite_comp_comp2 = fp_case fp @{thm rewriteR_comp_comp2} @{thm rewriteL_comp_comp2};
-    val rewrite_comp_comp = fp_case fp @{thm rewriteR_comp_comp} @{thm rewriteL_comp_comp};
-    val map_cong_passive_args1 = replicate m (fp_case fp @{thm id_comp} @{thm comp_id} RS fun_cong);
+    val rewrite_comp_comp2 = case_fp fp @{thm rewriteR_comp_comp2} @{thm rewriteL_comp_comp2};
+    val rewrite_comp_comp = case_fp fp @{thm rewriteR_comp_comp} @{thm rewriteL_comp_comp};
+    val map_cong_passive_args1 = replicate m (case_fp fp @{thm id_comp} @{thm comp_id} RS fun_cong);
     val map_cong_active_args1 = replicate n (if is_rec
-      then fp_case fp @{thm convol_o} @{thm o_case_sum} RS fun_cong
+      then case_fp fp @{thm convol_o} @{thm o_case_sum} RS fun_cong
       else refl);
-    val map_cong_passive_args2 = replicate m (fp_case fp @{thm comp_id} @{thm id_comp} RS fun_cong);
+    val map_cong_passive_args2 = replicate m (case_fp fp @{thm comp_id} @{thm id_comp} RS fun_cong);
     val map_cong_active_args2 = replicate n (if is_rec
-      then fp_case fp @{thm map_prod_o_convol_id} @{thm case_sum_o_map_sum_id}
-      else fp_case fp @{thm id_comp} @{thm comp_id} RS fun_cong);
+      then case_fp fp @{thm map_prod_o_convol_id} @{thm case_sum_o_map_sum_id}
+      else case_fp fp @{thm id_comp} @{thm comp_id} RS fun_cong);
     fun mk_map_congs passive active = map (fn thm => thm OF (passive @ active) RS ext) map_cong0s;
     val map_cong1s = mk_map_congs map_cong_passive_args1 map_cong_active_args1;
     val map_cong2s = mk_map_congs map_cong_passive_args2 map_cong_active_args2;
@@ -543,7 +547,7 @@
           (mk_trans rewrite1 (mk_sym rewrite2)))
       xtor_maps xtor_un_folds rewrite1s rewrite2s;
   in
-    split_conj_thm (un_fold_unique OF map (fp_case fp I mk_sym) unique_prems)
+    split_conj_thm (un_fold_unique OF map (case_fp fp I mk_sym) unique_prems)
   end;
 
 fun mk_strong_coinduct_thm coind rel_eqs rel_monos mk_vimage2p ctxt =