src/HOL/Tools/BNF/bnf_fp_util.ML
changeset 55803 74d3fe9031d8
parent 55706 064c7c249f55
child 55811 aa1acc25126b
--- a/src/HOL/Tools/BNF/bnf_fp_util.ML	Fri Feb 28 12:04:40 2014 +0100
+++ b/src/HOL/Tools/BNF/bnf_fp_util.ML	Tue Feb 25 18:14:26 2014 +0100
@@ -142,16 +142,15 @@
 
   val mk_case_sum: term * term -> term
   val mk_case_sumN: term list -> term
-  val mk_case_sumN_balanced: term list -> term
+  val mk_case_absumprod: typ -> term -> term list -> term list -> term list list -> term
+
   val mk_Inl: typ -> term -> term
   val mk_Inr: typ -> term -> term
   val mk_InN: typ list -> term -> int -> term
-  val mk_InN_balanced: typ -> int -> term -> int -> term
+  val mk_absumprod: typ -> term -> int -> int -> term list -> term
 
   val dest_sumT: typ -> typ * typ
-  val dest_sumTN: int -> typ -> typ list
-  val dest_sumTN_balanced: int -> typ -> typ list
-  val dest_tupleT: int -> typ -> typ list
+  val dest_absumprodT: typ -> typ -> int -> int list -> typ -> typ list list
 
   val If_const: typ -> term
 
@@ -160,8 +159,8 @@
   val mk_union: term * term -> term
 
   val mk_sumEN: int -> thm
-  val mk_sumEN_balanced: int -> thm
-  val mk_sumEN_tupled_balanced: int list -> thm
+  val mk_absumprodE: thm -> int list -> thm
+
   val mk_sum_caseN: int -> int -> thm
   val mk_sum_caseN_balanced: int -> int -> thm
 
@@ -176,12 +175,12 @@
   val mk_xtor_un_fold_o_map_thms: BNF_Util.fp_kind -> bool -> int -> thm -> thm list -> thm list ->
     thm list -> thm list -> thm list
 
-  val mk_strong_coinduct_thm: thm -> thm list -> thm list -> Proof.context -> thm
+  val mk_strong_coinduct_thm: thm -> thm list -> thm list -> (thm -> thm) -> Proof.context -> thm
 
   val fp_bnf: (binding list -> (string * sort) list -> typ list * typ list list ->
-      BNF_Def.bnf list -> local_theory -> 'a) ->
+      BNF_Def.bnf list -> BNF_Comp.absT_info list -> local_theory -> 'a) ->
     binding list -> (string * sort) list -> (string * sort) list -> ((string * sort) * typ) list ->
-    local_theory -> BNF_Def.bnf list * 'a
+    local_theory -> (BNF_Def.bnf list * BNF_Comp.absT_info list) * 'a
 end;
 
 structure BNF_FP_Util : BNF_FP_UTIL =
@@ -347,9 +346,6 @@
 
 fun dest_sumT (Type (@{type_name sum}, [T, T'])) = (T, T');
 
-fun dest_sumTN 1 T = [T]
-  | dest_sumTN n (Type (@{type_name sum}, [T, T'])) = T :: dest_sumTN (n - 1) T';
-
 val dest_sumTN_balanced = Balanced_Tree.dest dest_sumT;
 
 (* TODO: move something like this to "HOLogic"? *)
@@ -357,6 +353,8 @@
   | dest_tupleT 1 T = [T]
   | dest_tupleT n (Type (@{type_name prod}, [T, T'])) = T :: dest_tupleT (n - 1) T';
 
+fun dest_absumprodT absT repT n ms = map2 dest_tupleT ms o dest_sumTN_balanced n o mk_repT absT repT;
+
 val mk_sumTN = Library.foldr1 mk_sumT;
 val mk_sumTN_balanced = Balanced_Tree.make mk_sumT;
 
@@ -397,6 +395,10 @@
     |> repair_types sum_T
   end;
 
+fun mk_absumprod absT abs0 n k ts =
+  let val abs = mk_abs absT abs0;
+  in abs $ mk_InN_balanced (domain_type (fastype_of abs)) n (HOLogic.mk_tuple ts) k end;
+
 fun mk_case_sum (f, g) =
   let
     val fT = fastype_of f;
@@ -409,6 +411,12 @@
 val mk_case_sumN = Library.foldr1 mk_case_sum;
 val mk_case_sumN_balanced = Balanced_Tree.make mk_case_sum;
 
+fun mk_tupled_fun f x xs =
+  if xs = [x] then f else HOLogic.tupled_lambda x (Term.list_comb (f, xs));
+
+fun mk_case_absumprod absT rep fs xs xss =
+  HOLogic.mk_comp (mk_case_sumN_balanced (map3 mk_tupled_fun fs xs xss), mk_rep absT rep);
+
 fun If_const T = Const (@{const_name If}, HOLogic.boolT --> T --> T --> T);
 fun mk_If p t f = let val T = fastype_of t in If_const T $ p $ t $ f end;
 
@@ -441,21 +449,15 @@
   Balanced_Tree.make (fn (thm1, thm2) => thm1 RSN (1, thm2 RSN (2, @{thm obj_sumE_f})))
     (replicate n asm_rl);
 
-fun mk_sumEN_balanced' n all_impIs = mk_obj_sumEN_balanced n OF all_impIs RS @{thm obj_one_pointE};
-
-fun mk_sumEN_balanced 1 = @{thm one_pointE} (*optimization*)
-  | mk_sumEN_balanced 2 = @{thm sumE} (*optimization*)
-  | mk_sumEN_balanced n = mk_sumEN_balanced' n (replicate n (impI RS allI));
-
 fun mk_tupled_allIN 0 = @{thm unit_all_impI}
   | mk_tupled_allIN 1 = @{thm impI[THEN allI]}
   | mk_tupled_allIN 2 = @{thm prod_all_impI} (*optimization*)
   | mk_tupled_allIN n = mk_tupled_allIN (n - 1) RS @{thm prod_all_impI_step};
 
-fun mk_sumEN_tupled_balanced ms =
+fun mk_absumprodE type_definition ms =
   let val n = length ms in
-    if forall (curry op = 1) ms then mk_sumEN_balanced n
-    else mk_sumEN_balanced' n (map mk_tupled_allIN ms)
+    mk_obj_sumEN_balanced n OF map mk_tupled_allIN ms RS
+      (type_definition RS @{thm type_copy_obj_one_point_absE})
   end;
 
 fun mk_sum_caseN 1 1 = refl
@@ -543,7 +545,7 @@
     split_conj_thm (un_fold_unique OF map (fp_case fp I mk_sym) unique_prems)
   end;
 
-fun mk_strong_coinduct_thm coind rel_eqs rel_monos ctxt =
+fun mk_strong_coinduct_thm coind rel_eqs rel_monos mk_vimage2p ctxt =
   let
     val n = Thm.nprems_of coind;
     val m = Thm.nprems_of (hd rel_monos) - n;
@@ -554,7 +556,7 @@
       let
         val eq = iffD2 OF [rel_eq RS @{thm predicate2_eqD}, refl];
         val mono = rel_mono OF (replicate m @{thm order_refl} @ replicate n @{thm eq_subset});
-      in eq RS (mono RS @{thm predicate2D}) RS @{thm eqTrueI} end;
+      in mk_vimage2p (eq RS (mono RS @{thm predicate2D})) RS @{thm eqTrueI} end;
     val unfolds = map2 mk_unfold rel_eqs rel_monos @ @{thms sup_fun_def sup_bool_def
       imp_disjL all_conj_distrib subst_eq_imp simp_thms(18,21,35)};
   in
@@ -603,18 +605,19 @@
     fun pre_qualify b = Binding.qualify false (Binding.name_of b)
       #> Config.get lthy' bnf_note_all = false ? Binding.conceal;
 
-    val ((pre_bnfs, deadss), lthy'') =
+    val ((pre_bnfs, (deadss, absT_infos)), lthy'') =
       fold_map3 (fn b => seal_bnf (pre_qualify b) unfold_set' (Binding.prefix_name preN b))
         bs Dss bnfs' lthy'
-      |>> split_list;
+      |>> split_list
+      |>> apsnd split_list;
 
     val timer = time (timer "Normalization & sealing of BNFs");
 
-    val res = construct_fp bs resBs (map TFree resDs, deadss) pre_bnfs lthy'';
+    val res = construct_fp bs resBs (map TFree resDs, deadss) pre_bnfs absT_infos lthy'';
 
     val timer = time (timer "FP construction in total");
   in
-    timer; (pre_bnfs, res)
+    timer; ((pre_bnfs, absT_infos), res)
   end;
 
 end;