src/HOL/Tools/BNF/bnf_fp_util.ML
changeset 62905 52c5a25e0c96
parent 62863 e0b894bba6ff
child 62907 9ad0bac25a84
--- a/src/HOL/Tools/BNF/bnf_fp_util.ML	Thu Apr 07 17:26:22 2016 +0200
+++ b/src/HOL/Tools/BNF/bnf_fp_util.ML	Thu Apr 07 17:56:22 2016 +0200
@@ -188,6 +188,8 @@
   val mk_sum_Cinfinite: thm list -> thm
   val mk_sum_card_order: thm list -> thm
 
+  val force_typ: Proof.context -> typ -> term -> term
+
   val mk_xtor_rel_co_induct_thm: BNF_Util.fp_kind -> term list -> term list -> term list ->
     term list -> term list -> term list -> term list -> term list ->
     ({prems: thm list, context: Proof.context} -> tactic) -> Proof.context -> thm
@@ -196,6 +198,10 @@
     ({prems: thm list, context: Proof.context} -> tactic) -> Proof.context -> thm list
   val mk_xtor_co_iter_o_map_thms: BNF_Util.fp_kind -> bool -> int -> thm -> thm list -> thm list ->
     thm list -> thm list -> thm list
+  val derive_xtor_co_recs: BNF_Util.fp_kind -> binding list -> (typ list -> typ list) ->
+    (typ list list * typ list) -> BNF_Def.bnf list -> term list -> term list ->
+    thm -> thm list -> thm list -> thm list -> thm list -> local_theory ->
+    (term list * (thm list * thm * thm list * thm list)) * local_theory
 
   val fixpoint_bnf: (binding -> binding) ->
       (binding list -> (string * sort) list -> typ list * typ list list -> BNF_Def.bnf list ->
@@ -213,6 +219,7 @@
 open BNF_Comp
 open BNF_Def
 open BNF_Util
+open BNF_FP_Util_Tactics
 
 type fp_result =
   {Ts: typ list,
@@ -611,6 +618,232 @@
     split_conj_thm (un_fold_unique OF map (case_fp fp I mk_sym) unique_prems)
   end;
 
+fun force_typ ctxt T =
+  Term.map_types Type_Infer.paramify_vars
+  #> Type.constraint T
+  #> Syntax.check_term ctxt
+  #> singleton (Variable.polymorphic ctxt);
+
+fun mk_xtor_un_fold_xtor_thms fp xtor_un_fold_unique_thm map_id0s =
+  (xtor_un_fold_unique_thm OF
+    map (fn thm => case_fp fp
+      (mk_trans @{thm id_o}
+        (mk_sym (thm RS @{thm trans[OF arg_cong2[of _ _ _ _ "op o", OF refl] o_id]})))
+      (mk_trans (thm RS @{thm arg_cong2[of _ _ _ _ "op o", OF _ refl]})
+        @{thm trans[OF id_o o_id[symmetric]]}))
+    map_id0s)
+  |> split_conj_thm |> map mk_sym;
+
+fun derive_xtor_co_recs fp bs mk_Ts (Dss, resDs) pre_bnfs xtors0 un_folds0
+    xtor_un_fold_unique xtor_un_folds xtor_un_fold_transfers xtor_maps xtor_rels lthy =
+  let
+    fun co_swap pair = case_fp fp I swap pair;
+    val mk_co_comp = curry (HOLogic.mk_comp o co_swap);
+    fun mk_co_algT T U = case_fp fp (T --> U) (U --> T);
+    val co_alg_funT = case_fp fp domain_type range_type;
+    val mk_co_product = curry (case_fp fp mk_convol mk_case_sum);
+    val co_proj1_const = case_fp fp fst_const (uncurry Inl_const o dest_sumT) o co_alg_funT;
+    val co_proj2_const = case_fp fp snd_const (uncurry Inr_const o dest_sumT) o co_alg_funT;
+    val mk_co_productT = curry (case_fp fp HOLogic.mk_prodT mk_sumT);
+
+    val n = length pre_bnfs;
+    val live = live_of_bnf (hd pre_bnfs);
+    val m = live - n;
+    val ks = 1 upto n;
+
+    val map_id0s = map map_id0_of_bnf pre_bnfs;
+    val map_comps = map map_comp_of_bnf pre_bnfs;
+    val map_cong0s = map map_cong0_of_bnf pre_bnfs;
+    val map_transfers = map map_transfer_of_bnf pre_bnfs;
+    val sym_map_comp0s = map (mk_sym o map_comp0_of_bnf) pre_bnfs;
+
+    val deads = fold (union (op =)) Dss resDs;
+    val ((((As, Bs), Xs), Ys), names_lthy) = lthy
+      |> fold Variable.declare_typ deads
+      |> mk_TFrees m
+      ||>> mk_TFrees m
+      ||>> mk_TFrees n
+      ||>> mk_TFrees n;
+
+    val XFTs = @{map 2} (fn Ds => mk_T_of_bnf Ds (As @ Xs)) Dss pre_bnfs;
+    val co_algXFTs = @{map 2} mk_co_algT XFTs Xs;
+    val Ts = mk_Ts As;
+    val un_foldTs = @{map 2} (fn T => fn X => co_algXFTs ---> mk_co_algT T X) Ts Xs;
+    val un_folds = @{map 2} (force_typ names_lthy) un_foldTs un_folds0;
+    val ABs = As ~~ Bs;
+    val XYs = Xs ~~ Ys;
+
+    val Us = map (typ_subst_atomic ABs) Ts;
+
+    val TFTs = @{map 2} (fn Ds => mk_T_of_bnf Ds (As @ Ts)) Dss pre_bnfs;
+
+    val xtors = @{map 3} (force_typ names_lthy oo mk_co_algT) TFTs Ts xtors0;
+
+    val ids = map HOLogic.id_const As;
+    val co_rec_Xs = @{map 2} mk_co_productT Ts Xs;
+    val co_rec_Ys = @{map 2} mk_co_productT Us Ys;
+    val co_rec_algXs = @{map 2} mk_co_algT co_rec_Xs Xs;
+    val co_proj1s = map co_proj1_const co_rec_algXs;
+    val co_rec_maps = @{map 2} (fn Ds =>
+      mk_map_of_bnf Ds (As @ case_fp fp co_rec_Xs Ts) (As @ case_fp fp Ts co_rec_Xs)) Dss pre_bnfs;
+    val co_rec_Ts = @{map 2} (fn Ds => mk_T_of_bnf Ds (As @ co_rec_Xs)) Dss pre_bnfs
+    val co_rec_argTs = @{map 2} mk_co_algT co_rec_Ts Xs;
+    val co_rec_resTs = @{map 2} mk_co_algT Ts Xs;
+
+    val (((co_rec_ss, fs), xs), names_lthy) = names_lthy
+      |> mk_Frees "s" co_rec_argTs
+      ||>> mk_Frees "f" co_rec_resTs
+      ||>> mk_Frees "x" (case_fp fp TFTs Xs);
+
+    val co_rec_strs =
+      @{map 3} (fn xtor => fn s => fn mapx =>
+        mk_co_product (mk_co_comp xtor (list_comb (mapx, ids @ co_proj1s))) s)
+      xtors co_rec_ss co_rec_maps;
+
+    val theta = Xs ~~ co_rec_Xs;
+    val co_rec_un_folds = map (subst_atomic_types theta) un_folds;
+
+    val co_rec_spec0s = map (fn un_fold => list_comb (un_fold, co_rec_strs)) co_rec_un_folds;
+
+    val co_rec_ids = @{map 2} (mk_co_comp o co_proj1_const) co_rec_algXs co_rec_spec0s;
+    val co_rec_specs = @{map 2} (mk_co_comp o co_proj2_const) co_rec_algXs co_rec_spec0s;
+
+    val co_recN = case_fp fp ctor_recN dtor_corecN;
+    fun co_rec_bind i = nth bs (i - 1) |> Binding.prefix_name (co_recN ^ "_");
+    val co_rec_def_bind = rpair [] o Binding.concealed o Thm.def_binding o co_rec_bind;
+
+    fun co_rec_spec i =
+      fold_rev (Term.absfree o Term.dest_Free) co_rec_ss (nth co_rec_specs (i - 1));
+
+    val ((co_rec_frees, (_, co_rec_def_frees)), (lthy, lthy_old)) =
+      lthy
+      |> Local_Theory.open_target |> snd
+      |> fold_map (fn i =>
+        Local_Theory.define ((co_rec_bind i, NoSyn), (co_rec_def_bind i, co_rec_spec i))) ks
+      |>> apsnd split_list o split_list
+      ||> `Local_Theory.close_target;
+
+    val phi = Proof_Context.export_morphism lthy_old lthy;
+    val co_rec_names = map (fst o dest_Const o Morphism.term phi) co_rec_frees;
+    val co_recs = @{map 2} (fn name => fn resT =>
+      Const (name, co_rec_argTs ---> resT)) co_rec_names co_rec_resTs;
+    val co_rec_defs = map (fn def =>
+      mk_unabs_def n (Morphism.thm phi def RS meta_eq_to_obj_eq)) co_rec_def_frees;
+
+    val xtor_un_fold_xtor_thms = mk_xtor_un_fold_xtor_thms fp xtor_un_fold_unique map_id0s;
+
+    val co_rec_id_thms =
+      let
+        val goal = @{map 2} (fn T => fn t => HOLogic.mk_eq (t, HOLogic.id_const T)) Ts co_rec_ids
+          |> Library.foldr1 HOLogic.mk_conj |> HOLogic.mk_Trueprop;
+        val vars = Variable.add_free_names lthy goal [];
+      in
+        Goal.prove_sorry lthy vars [] goal
+          (fn {context = ctxt, prems = _} => mk_xtor_co_rec_id_tac ctxt xtor_un_fold_xtor_thms
+            xtor_un_fold_unique xtor_un_folds map_comps)
+          |> Thm.close_derivation
+          |> split_conj_thm
+      end;
+
+    val co_rec_app_ss = map (fn co_rec => list_comb (co_rec, co_rec_ss)) co_recs;
+    val co_products = @{map 2} (fn T => mk_co_product (HOLogic.id_const T)) Ts co_rec_app_ss;
+    val co_rec_maps_rev = @{map 2} (fn Ds =>
+      mk_map_of_bnf Ds (As @ case_fp fp Ts co_rec_Xs) (As @ case_fp fp co_rec_Xs Ts)) Dss pre_bnfs;
+    fun mk_co_app f g x = case_fp fp (f $ (g $ x)) (g $ (f $ x));
+    val co_rec_expand_thms = map (fn thm => thm RS
+      case_fp fp @{thm convol_expand_snd} @{thm case_sum_expand_Inr_pointfree}) co_rec_id_thms;
+    val xtor_co_rec_thms =
+      let
+        fun mk_goal co_rec s mapx xtor x =
+          let
+            val lhs = mk_co_app co_rec xtor x;
+            val rhs = mk_co_app s (list_comb (mapx, ids @ co_products)) x;
+          in
+            mk_Trueprop_eq (lhs, rhs)
+          end;
+        val goals = @{map 5} mk_goal co_rec_app_ss co_rec_ss co_rec_maps_rev xtors xs;
+      in
+        map2 (fn goal => fn un_fold =>
+          Variable.add_free_names lthy goal []
+          |> (fn vars => Goal.prove_sorry lthy vars [] goal
+            (fn {context = ctxt, prems = _} =>
+              mk_xtor_co_rec_tac ctxt un_fold co_rec_defs co_rec_expand_thms))
+          |> Thm.close_derivation)
+        goals xtor_un_folds
+      end;
+
+    val co_product_fs = @{map 2} (fn T => mk_co_product (HOLogic.id_const T)) Ts fs;
+    val co_rec_expand'_thms = map (fn thm =>
+      thm RS case_fp fp @{thm convol_expand_snd'} @{thm case_sum_expand_Inr'}) co_rec_id_thms;
+    val xtor_co_rec_unique_thm =
+      let
+        fun mk_prem f s mapx xtor =
+          let
+            val lhs = mk_co_comp f xtor;
+            val rhs = mk_co_comp s (list_comb (mapx, ids @ co_product_fs));
+          in
+            mk_Trueprop_eq (co_swap (lhs, rhs))
+          end;
+        val prems = @{map 4} mk_prem fs co_rec_ss co_rec_maps_rev xtors;
+        val concl = @{map 2} (curry HOLogic.mk_eq) fs co_rec_app_ss
+          |> Library.foldr1 HOLogic.mk_conj |> HOLogic.mk_Trueprop;
+        val goal = Logic.list_implies (prems, concl);
+        val vars = Variable.add_free_names lthy goal [];
+      in
+        Goal.prove_sorry lthy vars [] goal
+          (fn {context = ctxt, prems = _} => mk_xtor_co_rec_unique_tac ctxt fp co_rec_defs
+            co_rec_expand'_thms xtor_un_fold_unique map_id0s sym_map_comp0s)
+        |> Thm.close_derivation
+      end;
+
+    val xtor_co_rec_o_map_thms = mk_xtor_co_iter_o_map_thms fp true m xtor_co_rec_unique_thm
+      (map (mk_pointfree lthy) xtor_maps) (map (mk_pointfree lthy) xtor_co_rec_thms)
+      sym_map_comp0s map_cong0s;
+
+    val ABphiTs = @{map 2} mk_pred2T As Bs;
+    val XYphiTs = @{map 2} mk_pred2T Xs Ys;
+
+    val ((ABphis, XYphis), names_lthy) = names_lthy
+      |> mk_Frees "R" ABphiTs
+      ||>> mk_Frees "S" XYphiTs;
+
+    val pre_rels =
+      @{map 2} (fn Ds => mk_rel_of_bnf Ds (As @ co_rec_Xs) (Bs @ co_rec_Ys)) Dss pre_bnfs;
+    val rels = @{map 3} (fn T => fn T' => Thm.prop_of #> HOLogic.dest_Trueprop
+        #> fst o dest_comb #> fst o dest_comb #> funpow n (snd o dest_comb)
+        #> case_fp fp (fst o dest_comb #> snd o dest_comb) (snd o dest_comb) #> head_of
+        #> force_typ names_lthy (ABphiTs ---> mk_pred2T T T'))
+      Ts Us xtor_un_fold_transfers;
+
+    fun tac {context = ctxt, prems = _} = mk_xtor_co_rec_transfer_tac ctxt fp n m co_rec_defs
+      xtor_un_fold_transfers map_transfers xtor_rels;
+
+    val mk_rel_co_product = case_fp fp mk_rel_prod mk_rel_sum;
+    val rec_phis =
+      map2 (fn rel => mk_rel_co_product (Term.list_comb (rel, ABphis))) rels XYphis;
+
+    val xtor_co_rec_transfer_thms =
+      mk_xtor_co_iter_transfer_thms fp pre_rels rec_phis XYphis rels ABphis
+        co_recs (map (subst_atomic_types (ABs @ XYs)) co_recs) tac lthy;
+
+    val notes =
+      [(case_fp fp ctor_recN dtor_corecN, xtor_co_rec_thms),
+       (case_fp fp ctor_rec_uniqueN dtor_corec_uniqueN, split_conj_thm xtor_co_rec_unique_thm),
+       (case_fp fp ctor_rec_o_mapN dtor_corec_o_mapN, xtor_co_rec_o_map_thms),
+       (case_fp fp ctor_rec_transferN dtor_corec_transferN, xtor_co_rec_transfer_thms)]
+      |> map (apsnd (map single))
+      |> maps (fn (thmN, thmss) =>
+        map2 (fn b => fn thms =>
+          ((Binding.qualify true (Binding.name_of b) (Binding.name thmN), []), [(thms, [])]))
+        bs thmss);
+
+     val lthy = lthy |> Config.get lthy bnf_internals ? snd o Local_Theory.notes notes;
+  in
+    ((co_recs,
+     (xtor_co_rec_thms, xtor_co_rec_unique_thm, xtor_co_rec_o_map_thms, xtor_co_rec_transfer_thms)),
+      lthy)
+  end;
+
 fun fixpoint_bnf extra_qualify construct_fp bs resBs Ds0 fp_eqs comp_cache0 lthy =
   let
     val time = time lthy;