src/HOL/Tools/BNF/bnf_fp_util.ML
changeset 63045 c50c764aab10
parent 62907 9ad0bac25a84
child 63796 45c8762353dd
--- a/src/HOL/Tools/BNF/bnf_fp_util.ML	Fri Apr 22 15:34:37 2016 +0200
+++ b/src/HOL/Tools/BNF/bnf_fp_util.ML	Thu Apr 14 20:29:42 2016 +0200
@@ -200,7 +200,9 @@
     thm list -> thm list -> thm list
   val derive_xtor_co_recs: BNF_Util.fp_kind -> binding list -> (typ list -> typ list) ->
     (typ list list * typ list) -> BNF_Def.bnf list -> term list -> term list ->
-    thm -> thm list -> thm list -> thm list -> thm list -> local_theory ->
+    thm -> thm list -> thm list -> thm list -> thm list ->
+    (BNF_Comp.absT_info * BNF_Comp.absT_info) option list ->
+    local_theory ->
     (term list * (thm list * thm * thm list * thm list)) * local_theory
 
   val fixpoint_bnf: (binding -> binding) ->
@@ -624,19 +626,65 @@
   #> Syntax.check_term ctxt
   #> singleton (Variable.polymorphic ctxt);
 
-fun mk_xtor_un_fold_xtor_thms fp xtor_un_fold_unique_thm map_id0s =
-  (xtor_un_fold_unique_thm OF
-    map (fn thm => case_fp fp
-      (mk_trans @{thm id_o}
-        (mk_sym (thm RS @{thm trans[OF arg_cong2[of _ _ _ _ "op o", OF refl] o_id]})))
-      (mk_trans (thm RS @{thm arg_cong2[of _ _ _ _ "op o", OF _ refl]})
-        @{thm trans[OF id_o o_id[symmetric]]}))
-    map_id0s)
-  |> split_conj_thm |> map mk_sym;
+fun absT_info_encodeT thy (SOME (src : absT_info, dst : absT_info)) src_absT =
+    let
+      val src_repT = mk_repT (#absT src) (#repT src) src_absT;
+      val dst_absT = mk_absT thy (#repT dst) (#absT dst) src_repT;
+    in
+      dst_absT
+    end
+  | absT_info_encodeT _ NONE T = T;
+
+fun absT_info_decodeT thy = absT_info_encodeT thy o Option.map swap;
+
+fun absT_info_encode thy fp (opt as SOME (src : absT_info, dst : absT_info)) t =
+    let
+      val co_alg_funT = case_fp fp domain_type range_type;
+      fun co_swap pair = case_fp fp I swap pair;
+      val mk_co_comp = curry (HOLogic.mk_comp o co_swap);
+      val mk_co_abs = case_fp fp mk_abs mk_rep;
+      val mk_co_rep = case_fp fp mk_rep mk_abs;
+      val co_abs = case_fp fp #abs #rep;
+      val co_rep = case_fp fp #rep #abs;
+      val src_absT = co_alg_funT (fastype_of t);
+      val dst_absT = absT_info_encodeT thy opt src_absT;
+      val co_src_abs = mk_co_abs src_absT (co_abs src);
+      val co_dst_rep = mk_co_rep dst_absT (co_rep dst);
+    in
+      mk_co_comp (mk_co_comp t co_src_abs) co_dst_rep
+    end
+  | absT_info_encode _ _ NONE t = t;
+
+fun absT_info_decode thy fp = absT_info_encode thy fp o Option.map swap;
+
+fun mk_xtor_un_fold_xtor_thms ctxt fp un_folds xtors xtor_un_fold_unique map_id0s
+    absT_info_opts =
+  let
+    val thy = Proof_Context.theory_of ctxt;
+    fun mk_goal un_fold =
+      let
+        val rhs = list_comb (un_fold, @{map 2} (absT_info_encode thy fp) absT_info_opts xtors);
+        val T = range_type (fastype_of rhs);
+      in
+        HOLogic.mk_eq (HOLogic.id_const T, rhs)
+      end; 
+    val goal = HOLogic.mk_Trueprop (Library.foldr1 HOLogic.mk_conj (map mk_goal un_folds));
+    fun mk_inverses NONE = []
+      | mk_inverses (SOME (src, dst)) =
+        [#type_definition dst RS @{thm type_definition.Abs_inverse[OF _ UNIV_I]},
+         #type_definition src RS @{thm type_definition.Rep_inverse}];
+    val inverses = maps mk_inverses absT_info_opts;
+  in
+    Goal.prove_sorry ctxt [] [] goal (fn {context = ctxt, prems = _} =>
+      mk_xtor_un_fold_xtor_tac ctxt xtor_un_fold_unique map_id0s inverses)
+    |> split_conj_thm |> map mk_sym
+  end;
 
 fun derive_xtor_co_recs fp bs mk_Ts (Dss, resDs) pre_bnfs xtors0 un_folds0
-    xtor_un_fold_unique xtor_un_folds xtor_un_fold_transfers xtor_maps xtor_rels lthy =
+    xtor_un_fold_unique xtor_un_folds xtor_un_fold_transfers xtor_maps xtor_rels
+    absT_info_opts lthy =
   let
+    val thy = Proof_Context.theory_of lthy;
     fun co_swap pair = case_fp fp I swap pair;
     val mk_co_comp = curry (HOLogic.mk_comp o co_swap);
     fun mk_co_algT T U = case_fp fp (T --> U) (U --> T);
@@ -645,6 +693,7 @@
     val co_proj1_const = case_fp fp fst_const (uncurry Inl_const o dest_sumT) o co_alg_funT;
     val co_proj2_const = case_fp fp snd_const (uncurry Inr_const o dest_sumT) o co_alg_funT;
     val mk_co_productT = curry (case_fp fp HOLogic.mk_prodT mk_sumT);
+    val rewrite_comp_comp = case_fp fp @{thm rewriteL_comp_comp} @{thm rewriteR_comp_comp};
 
     val n = length pre_bnfs;
     val live = live_of_bnf (hd pre_bnfs);
@@ -677,7 +726,8 @@
 
     val TFTs = @{map 2} (fn Ds => mk_T_of_bnf Ds (As @ Ts)) Dss pre_bnfs;
 
-    val xtors = @{map 3} (force_typ names_lthy oo mk_co_algT) TFTs Ts xtors0;
+    val TFTs' = @{map 2} (absT_info_decodeT thy) absT_info_opts TFTs;
+    val xtors = @{map 3} (force_typ names_lthy oo mk_co_algT) TFTs' Ts xtors0;
 
     val ids = map HOLogic.id_const As;
     val co_rec_Xs = @{map 2} mk_co_productT Ts Xs;
@@ -693,12 +743,13 @@
     val (((co_rec_ss, fs), xs), names_lthy) = names_lthy
       |> mk_Frees "s" co_rec_argTs
       ||>> mk_Frees "f" co_rec_resTs
-      ||>> mk_Frees "x" (case_fp fp TFTs Xs);
+      ||>> mk_Frees "x" (case_fp fp TFTs' Xs);
 
     val co_rec_strs =
-      @{map 3} (fn xtor => fn s => fn mapx =>
-        mk_co_product (mk_co_comp xtor (list_comb (mapx, ids @ co_proj1s))) s)
-      xtors co_rec_ss co_rec_maps;
+      @{map 4} (fn xtor => fn s => fn mapx => fn absT_info_opt =>
+        mk_co_product (mk_co_comp (absT_info_encode thy fp absT_info_opt xtor)
+          (list_comb (mapx, ids @ co_proj1s))) s)
+      xtors co_rec_ss co_rec_maps absT_info_opts;
 
     val theta = Xs ~~ co_rec_Xs;
     val co_rec_un_folds = map (subst_atomic_types theta) un_folds;
@@ -730,7 +781,9 @@
     val co_rec_defs = map (fn def =>
       mk_unabs_def n (Morphism.thm phi def RS meta_eq_to_obj_eq)) co_rec_def_frees;
 
-    val xtor_un_fold_xtor_thms = mk_xtor_un_fold_xtor_thms fp xtor_un_fold_unique map_id0s;
+    val xtor_un_fold_xtor_thms =
+      mk_xtor_un_fold_xtor_thms lthy fp (map (Term.subst_atomic_types (Xs ~~ Ts)) un_folds)
+        xtors xtor_un_fold_unique map_id0s absT_info_opts;
 
     val co_rec_id_thms =
       let
@@ -741,8 +794,8 @@
         Goal.prove_sorry lthy vars [] goal
           (fn {context = ctxt, prems = _} => mk_xtor_co_rec_id_tac ctxt xtor_un_fold_xtor_thms
             xtor_un_fold_unique xtor_un_folds map_comps)
-          |> Thm.close_derivation
-          |> split_conj_thm
+        |> Thm.close_derivation
+        |> split_conj_thm
       end;
 
     val co_rec_app_ss = map (fn co_rec => list_comb (co_rec, co_rec_ss)) co_recs;
@@ -754,14 +807,16 @@
       case_fp fp @{thm convol_expand_snd} @{thm case_sum_expand_Inr_pointfree}) co_rec_id_thms;
     val xtor_co_rec_thms =
       let
-        fun mk_goal co_rec s mapx xtor x =
+        fun mk_goal co_rec s mapx xtor x absT_info_opt =
           let
             val lhs = mk_co_app co_rec xtor x;
-            val rhs = mk_co_app s (list_comb (mapx, ids @ co_products)) x;
+            val rhs = mk_co_app s
+              (list_comb (mapx, ids @ co_products) |> absT_info_decode thy fp absT_info_opt) x;
           in
             mk_Trueprop_eq (lhs, rhs)
           end;
-        val goals = @{map 5} mk_goal co_rec_app_ss co_rec_ss co_rec_maps_rev xtors xs;
+        val goals =
+          @{map 6} mk_goal co_rec_app_ss co_rec_ss co_rec_maps_rev xtors xs absT_info_opts;
       in
         map2 (fn goal => fn un_fold =>
           Variable.add_free_names lthy goal []
@@ -777,28 +832,38 @@
       thm RS case_fp fp @{thm convol_expand_snd'} @{thm case_sum_expand_Inr'}) co_rec_id_thms;
     val xtor_co_rec_unique_thm =
       let
-        fun mk_prem f s mapx xtor =
+        fun mk_prem f s mapx xtor absT_info_opt =
           let
             val lhs = mk_co_comp f xtor;
-            val rhs = mk_co_comp s (list_comb (mapx, ids @ co_product_fs));
+            val rhs = mk_co_comp s (list_comb (mapx, ids @ co_product_fs))
+              |> absT_info_decode thy fp absT_info_opt;
           in
             mk_Trueprop_eq (co_swap (lhs, rhs))
           end;
-        val prems = @{map 4} mk_prem fs co_rec_ss co_rec_maps_rev xtors;
+        val prems = @{map 5} mk_prem fs co_rec_ss co_rec_maps_rev xtors absT_info_opts;
         val concl = @{map 2} (curry HOLogic.mk_eq) fs co_rec_app_ss
           |> Library.foldr1 HOLogic.mk_conj |> HOLogic.mk_Trueprop;
         val goal = Logic.list_implies (prems, concl);
         val vars = Variable.add_free_names lthy goal [];
+        fun mk_inverses NONE = []
+          | mk_inverses (SOME (src, dst)) =
+            [#type_definition dst RS @{thm type_copy_Rep_o_Abs} RS rewrite_comp_comp,
+             #type_definition src RS @{thm type_copy_Abs_o_Rep}];
+        val inverses = maps mk_inverses absT_info_opts;
       in
         Goal.prove_sorry lthy vars [] goal
           (fn {context = ctxt, prems = _} => mk_xtor_co_rec_unique_tac ctxt fp co_rec_defs
-            co_rec_expand'_thms xtor_un_fold_unique map_id0s sym_map_comp0s)
+            co_rec_expand'_thms xtor_un_fold_unique map_id0s sym_map_comp0s inverses)
         |> Thm.close_derivation
       end;
 
-    val xtor_co_rec_o_map_thms = mk_xtor_co_iter_o_map_thms fp true m xtor_co_rec_unique_thm
-      (map (mk_pointfree lthy) xtor_maps) (map (mk_pointfree lthy) xtor_co_rec_thms)
-      sym_map_comp0s map_cong0s;
+    val xtor_co_rec_o_map_thms = if forall is_none absT_info_opts
+      then
+        mk_xtor_co_iter_o_map_thms fp true m xtor_co_rec_unique_thm
+          (map (mk_pointfree lthy) xtor_maps) (map (mk_pointfree lthy) xtor_co_rec_thms)
+          sym_map_comp0s map_cong0s
+      else
+        replicate n refl (* FIXME *);
 
     val ABphiTs = @{map 2} mk_pred2T As Bs;
     val XYphiTs = @{map 2} mk_pred2T Xs Ys;
@@ -807,24 +872,29 @@
       |> mk_Frees "R" ABphiTs
       ||>> mk_Frees "S" XYphiTs;
 
-    val pre_rels =
-      @{map 2} (fn Ds => mk_rel_of_bnf Ds (As @ co_rec_Xs) (Bs @ co_rec_Ys)) Dss pre_bnfs;
-    val rels = @{map 3} (fn T => fn T' => Thm.prop_of #> HOLogic.dest_Trueprop
-        #> fst o dest_comb #> fst o dest_comb #> funpow n (snd o dest_comb)
-        #> case_fp fp (fst o dest_comb #> snd o dest_comb) (snd o dest_comb) #> head_of
-        #> force_typ names_lthy (ABphiTs ---> mk_pred2T T T'))
-      Ts Us xtor_un_fold_transfers;
-
-    fun tac {context = ctxt, prems = _} = mk_xtor_co_rec_transfer_tac ctxt fp n m co_rec_defs
-      xtor_un_fold_transfers map_transfers xtor_rels;
-
-    val mk_rel_co_product = case_fp fp mk_rel_prod mk_rel_sum;
-    val rec_phis =
-      map2 (fn rel => mk_rel_co_product (Term.list_comb (rel, ABphis))) rels XYphis;
-
-    val xtor_co_rec_transfer_thms =
-      mk_xtor_co_iter_transfer_thms fp pre_rels rec_phis XYphis rels ABphis
-        co_recs (map (subst_atomic_types (ABs @ XYs)) co_recs) tac lthy;
+    val xtor_co_rec_transfer_thms = if forall is_none absT_info_opts
+      then
+        let
+          val pre_rels =
+            @{map 2} (fn Ds => mk_rel_of_bnf Ds (As @ co_rec_Xs) (Bs @ co_rec_Ys)) Dss pre_bnfs;
+          val rels = @{map 3} (fn T => fn T' => Thm.prop_of #> HOLogic.dest_Trueprop
+              #> fst o dest_comb #> fst o dest_comb #> funpow n (snd o dest_comb)
+              #> case_fp fp (fst o dest_comb #> snd o dest_comb) (snd o dest_comb) #> head_of
+              #> force_typ names_lthy (ABphiTs ---> mk_pred2T T T'))
+            Ts Us xtor_un_fold_transfers;
+      
+          fun tac {context = ctxt, prems = _} = mk_xtor_co_rec_transfer_tac ctxt fp n m co_rec_defs
+            xtor_un_fold_transfers map_transfers xtor_rels;
+      
+          val mk_rel_co_product = case_fp fp mk_rel_prod mk_rel_sum;
+          val rec_phis =
+            map2 (fn rel => mk_rel_co_product (Term.list_comb (rel, ABphis))) rels XYphis;
+        in
+          mk_xtor_co_iter_transfer_thms fp pre_rels rec_phis XYphis rels ABphis
+            co_recs (map (subst_atomic_types (ABs @ XYs)) co_recs) tac lthy
+        end
+      else
+        replicate n TrueI (* FIXME *);
 
     val notes =
       [(case_fp fp ctor_recN dtor_corecN, xtor_co_rec_thms),