implemented and use "mk_sum_casesN_balanced"
authorblanchet
Mon, 10 Sep 2012 18:29:55 +0200
changeset 49264 9059e0dbdbc1
parent 49263 669a820ef213
child 49265 059aa3088ae3
implemented and use "mk_sum_casesN_balanced"
src/HOL/Codatatype/Tools/bnf_fp_sugar.ML
src/HOL/Codatatype/Tools/bnf_fp_sugar_tactics.ML
src/HOL/Codatatype/Tools/bnf_fp_util.ML
--- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Mon Sep 10 17:52:01 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Mon Sep 10 18:29:55 2012 +0200
@@ -50,31 +50,6 @@
 fun mk_uncurried2_fun f xss =
   mk_tupled_fun (HOLogic.mk_tuple (map HOLogic.mk_tuple xss)) f (flat xss);
 
-val mk_sumTN_balanced = Balanced_Tree.make mk_sumT;
-val dest_sumTN_balanced = Balanced_Tree.dest dest_sumT;
-
-fun mk_InN_balanced sum_T n t k =
-  let
-    fun repair_types T (Const (s as @{const_name Inl}, _) $ t) = repair_inj_types T s fst t
-      | repair_types T (Const (s as @{const_name Inr}, _) $ t) = repair_inj_types T s snd t
-      | repair_types _ t = t
-    and repair_inj_types T s get t =
-      let val T' = get (dest_sumT T) in
-        Const (s, T' --> T) $ repair_types T' t
-      end;
-  in
-    Balanced_Tree.access {left = mk_Inl dummyT, right = mk_Inr dummyT, init = t} n k
-    |> repair_types sum_T
-  end;
-
-val mk_sum_caseN_balanced = Balanced_Tree.make mk_sum_case;
-
-fun mk_sumEN_balanced 1 = @{thm one_pointE}
-  | mk_sumEN_balanced 2 = @{thm sumE} (*optimization*)
-  | mk_sumEN_balanced n =
-    Balanced_Tree.make (fn (thm1, thm2) => thm1 RSN (1, thm2 RSN (2, @{thm obj_sumE_f})))
-      (replicate n asm_rl) OF (replicate n (impI RS allI)) RS @{thm obj_one_pointE};
-
 fun tick v f = Term.lambda v (HOLogic.mk_prod (v, f $ v));
 
 fun tack z_name (c, v) f =
--- a/src/HOL/Codatatype/Tools/bnf_fp_sugar_tactics.ML	Mon Sep 10 17:52:01 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar_tactics.ML	Mon Sep 10 18:29:55 2012 +0200
@@ -22,10 +22,11 @@
 
 open BNF_Tactics
 open BNF_Util
+open BNF_FP_Util
 
 fun mk_case_tac ctxt n k m case_def ctr_def unf_fld =
   Local_Defs.unfold_tac ctxt [case_def, ctr_def, unf_fld] THEN
-  (rtac (BNF_FP_Util.mk_sum_casesN n k RS ssubst) THEN'
+  (rtac (mk_sum_casesN_balanced n k RS ssubst) THEN'
    REPEAT_DETERM_N (Int.max (0, m - 1)) o rtac (@{thm split} RS ssubst) THEN'
    rtac refl) 1;
 
--- a/src/HOL/Codatatype/Tools/bnf_fp_util.ML	Mon Sep 10 17:52:01 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_fp_util.ML	Mon Sep 10 18:29:55 2012 +0200
@@ -82,6 +82,7 @@
   val split_conj_prems: int -> thm -> thm
 
   val mk_sumTN: typ list -> typ
+  val mk_sumTN_balanced: typ list -> typ
 
   val Inl_const: typ -> typ -> term
   val Inr_const: typ -> typ -> term
@@ -89,18 +90,23 @@
   val mk_Inl: typ -> term -> term
   val mk_Inr: typ -> term -> term
   val mk_InN: typ list -> term -> int -> term
+  val mk_InN_balanced: typ -> int -> term -> int -> term
   val mk_sum_case: term * term -> term
   val mk_sum_caseN: term list -> term
+  val mk_sum_caseN_balanced: term list -> term
 
   val dest_sumT: typ -> typ * typ
   val dest_sumTN: int -> typ -> typ list
+  val dest_sumTN_balanced: int -> typ -> typ list
   val dest_tupleT: int -> typ -> typ list
 
   val mk_Field: term -> term
   val mk_union: term * term -> term
 
   val mk_sumEN: int -> thm
+  val mk_sumEN_balanced: int -> thm
   val mk_sum_casesN: int -> int -> thm
+  val mk_sum_casesN_balanced: int -> int -> thm
 
   val mk_tactics: 'a -> 'a -> 'a -> 'a list -> 'a -> 'a -> 'a list -> 'a -> 'a -> 'a list
 
@@ -200,7 +206,20 @@
 val set_inclN = "set_incl"
 val set_set_inclN = "set_set_incl"
 
-fun mk_sumTN Ts = Library.foldr1 mk_sumT Ts;
+fun dest_sumT (Type (@{type_name sum}, [T, T'])) = (T, T');
+
+fun dest_sumTN 1 T = [T]
+  | dest_sumTN n (Type (@{type_name sum}, [T, T'])) = T :: dest_sumTN (n - 1) T';
+
+val dest_sumTN_balanced = Balanced_Tree.dest dest_sumT;
+
+(* TODO: move something like this to "HOLogic"? *)
+fun dest_tupleT 0 @{typ unit} = []
+  | dest_tupleT 1 T = [T]
+  | dest_tupleT n (Type (@{type_name prod}, [T, T'])) = T :: dest_tupleT (n - 1) T';
+
+val mk_sumTN = Library.foldr1 mk_sumT;
+val mk_sumTN_balanced = Balanced_Tree.make mk_sumT;
 
 fun Inl_const LT RT = Const (@{const_name Inl}, LT --> mk_sumT (LT, RT));
 fun mk_Inl RT t = Inl_const (fastype_of t) RT $ t;
@@ -213,6 +232,20 @@
   | mk_InN (LT :: Ts) t m = mk_Inr LT (mk_InN Ts t (m - 1))
   | mk_InN Ts t _ = raise (TYPE ("mk_InN", Ts, [t]));
 
+fun mk_InN_balanced sum_T n t k =
+  let
+    fun repair_types T (Const (s as @{const_name Inl}, _) $ t) = repair_inj_types T s fst t
+      | repair_types T (Const (s as @{const_name Inr}, _) $ t) = repair_inj_types T s snd t
+      | repair_types _ t = t
+    and repair_inj_types T s get t =
+      let val T' = get (dest_sumT T) in
+        Const (s, T' --> T) $ repair_types T' t
+      end;
+  in
+    Balanced_Tree.access {left = mk_Inl dummyT, right = mk_Inr dummyT, init = t} n k
+    |> repair_types sum_T
+  end;
+
 fun mk_sum_case (f, g) =
   let
     val fT = fastype_of f;
@@ -222,18 +255,8 @@
       fT --> gT --> mk_sumT (domain_type fT, domain_type gT) --> range_type fT) $ f $ g
   end;
 
-fun mk_sum_caseN [f] = f
-  | mk_sum_caseN (f :: fs) = mk_sum_case (f, mk_sum_caseN fs);
-
-fun dest_sumT (Type (@{type_name sum}, [T, T'])) = (T, T');
-
-fun dest_sumTN 1 T = [T]
-  | dest_sumTN n (Type (@{type_name sum}, [T, T'])) = T :: dest_sumTN (n - 1) T';
-
-(* TODO: move something like this to "HOLogic"? *)
-fun dest_tupleT 0 @{typ unit} = []
-  | dest_tupleT 1 T = [T]
-  | dest_tupleT n (Type (@{type_name prod}, [T, T'])) = T :: dest_tupleT (n - 1) T';
+val mk_sum_caseN = Library.foldr1 mk_sum_case;
+val mk_sum_caseN_balanced = Balanced_Tree.make mk_sum_case;
 
 fun mk_Field r =
   let val T = fst (dest_relT (fastype_of r));
@@ -260,10 +283,24 @@
     (fold (fn i => fn thm => @{thm obj_sum_step} RSN (i, thm)) (2 upto n - 1) @{thm obj_sumE}) OF
       replicate n (impI RS allI);
 
-fun mk_sum_casesN 1 1 = @{thm refl}
+fun mk_sumEN_balanced 1 = @{thm one_pointE}
+  | mk_sumEN_balanced 2 = @{thm sumE} (*optimization*)
+  | mk_sumEN_balanced n =
+    Balanced_Tree.make (fn (thm1, thm2) => thm1 RSN (1, thm2 RSN (2, @{thm obj_sumE_f})))
+      (replicate n asm_rl) OF (replicate n (impI RS allI)) RS @{thm obj_one_pointE};
+
+fun mk_sum_casesN 1 1 = refl
   | mk_sum_casesN _ 1 = @{thm sum.cases(1)}
   | mk_sum_casesN 2 2 = @{thm sum.cases(2)}
-  | mk_sum_casesN n m = trans OF [@{thm sum_case_step(2)}, mk_sum_casesN (n - 1) (m - 1)];
+  | mk_sum_casesN n k = trans OF [@{thm sum_case_step(2)}, mk_sum_casesN (n - 1) (k - 1)];
+
+fun mk_sum_step base step thm =
+  if Thm.eq_thm_prop (thm, refl) then base else trans OF [step, thm];
+
+fun mk_sum_casesN_balanced 1 1 = refl
+  | mk_sum_casesN_balanced n k =
+    Balanced_Tree.access {left = mk_sum_step @{thm sum.cases(1)} @{thm sum_case_step(1)},
+      right = mk_sum_step @{thm sum.cases(2)} @{thm sum_case_step(2)}, init = refl} n k;
 
 fun mk_tactics mid mcomp mcong snat bdco bdinf sbd inbd wpull =
   [mid, mcomp, mcong] @ snat @ [bdco, bdinf] @ sbd @ [inbd, wpull];