src/HOL/BNF/Tools/bnf_fp_def_sugar.ML
changeset 52301 7935e82a4ae4
parent 52300 4a4da43e855a
child 52302 867d5d16158c
equal deleted inserted replaced
52300:4a4da43e855a 52301:7935e82a4ae4
    26   val morph_fp_sugar: morphism -> fp_sugar -> fp_sugar
    26   val morph_fp_sugar: morphism -> fp_sugar -> fp_sugar
    27   val fp_sugar_of: Proof.context -> string -> fp_sugar option
    27   val fp_sugar_of: Proof.context -> string -> fp_sugar option
    28 
    28 
    29   val tvar_subst: theory -> typ list -> typ list -> ((string * int) * typ) list
    29   val tvar_subst: theory -> typ list -> typ list -> ((string * int) * typ) list
    30   val exists_subtype_in: typ list -> typ -> bool
    30   val exists_subtype_in: typ list -> typ -> bool
    31   val flat_rec: ('a -> 'b list) -> 'a list -> 'b list
    31   val flat_rec: 'a list list -> 'a list
    32   val mk_co_iter: theory -> BNF_FP_Util.fp_kind -> typ -> typ list -> term -> term
    32   val mk_co_iter: theory -> BNF_FP_Util.fp_kind -> typ -> typ list -> term -> term
    33   val nesty_bnfs: Proof.context -> typ list list list -> typ list -> BNF_Def.bnf list
    33   val nesty_bnfs: Proof.context -> typ list list list -> typ list -> BNF_Def.bnf list
    34   val indexify_fst: ''a list -> (int -> ''a * 'b -> 'c) -> ''a * 'b -> 'c
    34   val indexify_fst: ''a list -> (int -> ''a * 'b -> 'c) -> ''a * 'b -> 'c
    35   val mk_un_fold_co_rec_prelims: BNF_FP_Util.fp_kind -> typ list -> typ list -> int list ->
    35   val mk_un_fold_co_rec_prelims: BNF_FP_Util.fp_kind -> typ list -> typ list -> int list ->
    36     int list list -> term list -> term list -> Proof.context ->
    36     int list list -> term list -> term list -> Proof.context ->
    37     (term list * term list * ((term list list * typ list list * term list list list list)
    37     (term list * term list
    38        * (term list list * typ list list * term list list list list)) option
    38        * ((typ list list * typ list list list list * term list list * term list list list list)
       
    39           * (typ list list * typ list list list list * term list list
       
    40              * term list list list list)) option
    39      * (term list * term list list
    41      * (term list * term list list
    40         * ((term list list * term list list list list * term list list list list)
    42         * ((term list list * term list list list list * term list list list list)
    41            * (typ list * typ list list list * typ list list))
    43            * (typ list * typ list list list * typ list list))
    42         * ((term list list * term list list list list * term list list list list)
    44         * ((term list list * term list list list list * term list list list list)
    43            * (typ list * typ list list list * typ list list))) option)
    45            * (typ list * typ list list list * typ list list))) option)
    45   val mk_map: int -> typ list -> typ list -> term -> term
    47   val mk_map: int -> typ list -> typ list -> term -> term
    46   val build_map: local_theory -> (typ * typ -> term) -> typ * typ -> term
    48   val build_map: local_theory -> (typ * typ -> term) -> typ * typ -> term
    47 
    49 
    48   val mk_iter_fun_arg_typessss: typ list -> int list -> int list list -> term ->
    50   val mk_iter_fun_arg_typessss: typ list -> int list -> int list list -> term ->
    49     typ list list list list
    51     typ list list list list
    50   val define_fold_rec: (term list list * typ list list * term list list list list)
    52   val define_fold_rec:
    51       * (term list list * typ list list * term list list list list) -> (string -> binding) ->
    53     (typ list list * typ list list list list * term list list * term list list list list)
    52     typ list -> typ list -> term -> term -> Proof.context ->
    54      * (typ list list * typ list list list list * term list list * term list list list list) ->
       
    55     (string -> binding) -> typ list -> typ list -> term -> term -> Proof.context ->
    53     (term * term * thm * thm) * Proof.context
    56     (term * term * thm * thm) * Proof.context
    54   val define_unfold_corec: term list * term list list
    57   val define_unfold_corec: term list * term list list
    55       * ((term list list * term list list list list * term list list list list)
    58       * ((term list list * term list list list list * term list list list list)
    56          * (typ list * typ list list list * typ list list))
    59          * (typ list * typ list list list * typ list list))
    57       * ((term list list * term list list list list * term list list list list)
    60       * ((term list list * term list list list list * term list list list list)
   186     | SOME T' => T')
   189     | SOME T' => T')
   187   | typ_subst_nonatomic inst T = the_default T (AList.lookup (op =) inst T);
   190   | typ_subst_nonatomic inst T = the_default T (AList.lookup (op =) inst T);
   188 
   191 
   189 val lists_bmoc = fold (fn xs => fn t => Term.list_comb (t, xs));
   192 val lists_bmoc = fold (fn xs => fn t => Term.list_comb (t, xs));
   190 
   193 
   191 fun flat_rec unzipf xs =
   194 fun flat_rec xss =
   192   let val ps = map unzipf xs in
   195   (* The first line below gives the preferred order. The second line is for compatibility with the
   193     (* The first line below gives the preferred order. The second line is for compatibility with the
   196      old datatype package: *)
   194        old datatype package: *)
       
   195 (*
   197 (*
   196     flat ps
   198   flat xss
   197 *)
   199 *)
   198     map hd ps @ maps tl ps
   200   map hd xss @ maps tl xss;
   199   end;
       
   200 
   201 
   201 fun flat_predss_getterss qss fss = maps (op @) (qss ~~ fss);
   202 fun flat_predss_getterss qss fss = maps (op @) (qss ~~ fss);
   202 
   203 
   203 fun flat_preds_predsss_gettersss [] [qss] [fss] = flat_predss_getterss qss fss
   204 fun flat_preds_predsss_gettersss [] [qss] [fss] = flat_predss_getterss qss fss
   204   | flat_preds_predsss_gettersss (p :: ps) (qss :: qsss) (fss :: fsss) =
   205   | flat_preds_predsss_gettersss (p :: ps) (qss :: qsss) (fss :: fsss) =
   205     p :: flat_predss_getterss qss fss @ flat_preds_predsss_gettersss ps qsss fsss;
   206     p :: flat_predss_getterss qss fss @ flat_preds_predsss_gettersss ps qsss fsss;
   206 
   207 
   207 fun mk_tupled_fun x f xs = HOLogic.tupled_lambda x (Term.list_comb (f, xs));
   208 fun mk_tupled_fun x f xs = HOLogic.tupled_lambda x (Term.list_comb (f, xs));
   208 fun mk_uncurried_fun f xs = mk_tupled_fun (HOLogic.mk_tuple xs) f xs;
   209 fun mk_uncurried_fun f xs = mk_tupled_fun (HOLogic.mk_tuple xs) f xs;
   209 fun mk_uncurried2_fun f xss =
   210 fun mk_uncurried2_fun f xss =
   210   mk_tupled_fun (HOLogic.mk_tuple (map HOLogic.mk_tuple xss)) f (flat_rec I xss);
   211   mk_tupled_fun (HOLogic.mk_tuple (map HOLogic.mk_tuple xss)) f (flat_rec xss);
   211 
   212 
   212 fun mk_flip (x, Type (_, [T1, Type (_, [T2, T3])])) =
   213 fun mk_flip (x, Type (_, [T1, Type (_, [T2, T3])])) =
   213   Abs ("x", T1, Abs ("y", T2, Var (x, T2 --> T1 --> T3) $ Bound 0 $ Bound 1));
   214   Abs ("x", T1, Abs ("y", T2, Var (x, T2 --> T1 --> T3) $ Bound 0 $ Bound 1));
   214 
   215 
   215 fun flip_rels lthy n thm =
   216 fun flip_rels lthy n thm =
   283 
   284 
   284     val ((gss, ysss), lthy) =
   285     val ((gss, ysss), lthy) =
   285       lthy
   286       lthy
   286       |> mk_Freess "f" g_Tss
   287       |> mk_Freess "f" g_Tss
   287       ||>> mk_Freesss "x" y_Tsss;
   288       ||>> mk_Freesss "x" y_Tsss;
       
   289 
       
   290     val y_Tssss = map (map (map single)) y_Tsss;
   288     val yssss = map (map (map single)) ysss;
   291     val yssss = map (map (map single)) ysss;
   289 
   292 
   290     val z_Tssss =
   293     val z_Tssss =
   291       map3 (fn n => fn ms => map2 (map (unzip_recT Cs) oo dest_tupleT) ms o
   294       map3 (fn n => fn ms => map2 (map (unzip_recT Cs) oo dest_tupleT) ms o
   292         dest_sumTN_balanced n o domain_type) ns mss ctor_rec_fun_Ts;
   295         dest_sumTN_balanced n o domain_type) ns mss ctor_rec_fun_Ts;
   293 
   296 
   294     val z_Tsss' = map (map (flat_rec I)) z_Tssss;
   297     val z_Tsss' = map (map flat_rec) z_Tssss;
   295     val h_Tss = map2 (map2 (curry (op --->))) z_Tsss' Css;
   298     val h_Tss = map2 (map2 (curry (op --->))) z_Tsss' Css;
   296 
   299 
   297     val hss = map2 (map2 retype_free) h_Tss gss;
   300     val hss = map2 (map2 retype_free) h_Tss gss;
   298     val zssss_hd = map2 (map2 (map2 (retype_free o hd))) z_Tssss ysss;
   301     val zssss_hd = map2 (map2 (map2 (retype_free o hd))) z_Tssss ysss;
   299     val (zssss_tl, lthy) =
   302     val (zssss_tl, lthy) =
   300       lthy
   303       lthy
   301       |> mk_Freessss "y" (map (map (map tl)) z_Tssss);
   304       |> mk_Freessss "y" (map (map (map tl)) z_Tssss);
   302     val zssss = map2 (map2 (map2 cons)) zssss_hd zssss_tl;
   305     val zssss = map2 (map2 (map2 cons)) zssss_hd zssss_tl;
   303   in
   306   in
   304     (((gss, g_Tss, yssss), (hss, h_Tss, zssss)), lthy)
   307     (((g_Tss, y_Tssss, gss, yssss), (h_Tss, z_Tssss, hss, zssss)), lthy)
   305   end;
   308   end;
   306 
   309 
   307 fun mk_unfold_corec_args_types Cs ns mss dtor_unfold_fun_Ts dtor_corec_fun_Ts lthy =
   310 fun mk_unfold_corec_args_types Cs ns mss dtor_unfold_fun_Ts dtor_corec_fun_Ts lthy =
   308   let
   311   let
   309     (*avoid "'a itself" arguments in coiterators and corecursors*)
   312     (*avoid "'a itself" arguments in coiterators and corecursors*)
   481 
   484 
   482     val nn = length fpTs;
   485     val nn = length fpTs;
   483 
   486 
   484     val fpT_to_C as Type (_, [fpT, _]) = snd (strip_typeN nn (fastype_of ctor_fold));
   487     val fpT_to_C as Type (_, [fpT, _]) = snd (strip_typeN nn (fastype_of ctor_fold));
   485 
   488 
   486     fun generate_iter (suf, ctor_iter, (fss, f_Tss, xssss)) =
   489     fun generate_iter (suf, ctor_iter, (f_Tss, _, fss, xssss)) =
   487       let
   490       let
   488         val res_T = fold_rev (curry (op --->)) f_Tss fpT_to_C;
   491         val res_T = fold_rev (curry (op --->)) f_Tss fpT_to_C;
   489         val binding = mk_binding suf;
   492         val binding = mk_binding suf;
   490         val spec =
   493         val spec =
   491           mk_Trueprop_eq (lists_bmoc fss (Free (Binding.name_of binding, res_T)),
   494           mk_Trueprop_eq (lists_bmoc fss (Free (Binding.name_of binding, res_T)),
   567     val fp_b_names = map base_name_of_typ fpTs;
   570     val fp_b_names = map base_name_of_typ fpTs;
   568 
   571 
   569     val ctor_fold_fun_Ts = mk_fp_iter_fun_types (hd ctor_folds);
   572     val ctor_fold_fun_Ts = mk_fp_iter_fun_types (hd ctor_folds);
   570     val ctor_rec_fun_Ts = mk_fp_iter_fun_types (hd ctor_recs);
   573     val ctor_rec_fun_Ts = mk_fp_iter_fun_types (hd ctor_recs);
   571 
   574 
   572     val (((gss, _, _), (hss, _, _)), names_lthy0) =
   575     val (((_, y_Tssss, gss, _), (_, z_Tssss, hss, _)), names_lthy0) =
   573       mk_fold_rec_args_types Cs ns mss ctor_fold_fun_Ts ctor_rec_fun_Ts lthy;
   576       mk_fold_rec_args_types Cs ns mss ctor_fold_fun_Ts ctor_rec_fun_Ts lthy;
   574 
   577 
   575     val ((((ps, ps'), xsss), us'), names_lthy) =
   578     val ((((ps, ps'), xsss), us'), names_lthy) =
   576       names_lthy0
   579       names_lthy0
   577       |> mk_Frees' "P" (map mk_pred1T fpTs)
   580       |> mk_Frees' "P" (map mk_pred1T fpTs)
   669 
   672 
   670         fun mk_goal fss fiter xctr f xs fxs =
   673         fun mk_goal fss fiter xctr f xs fxs =
   671           fold_rev (fold_rev Logic.all) (xs :: fss)
   674           fold_rev (fold_rev Logic.all) (xs :: fss)
   672             (mk_Trueprop_eq (fiter $ xctr, Term.list_comb (f, fxs)));
   675             (mk_Trueprop_eq (fiter $ xctr, Term.list_comb (f, fxs)));
   673 
   676 
   674         val mk_U = typ_subst_nonatomic (map2 pair fpTs Cs);
       
   675 
       
   676         fun mk_nested_U maybe_mk_prodT =
       
   677           typ_subst_nonatomic (map2 (fn fpT => fn C => (fpT, maybe_mk_prodT fpT C)) fpTs Cs);
       
   678 
       
   679         fun unzip_iters fiters maybe_tick maybe_mk_prodT x =
       
   680           let val Free (_, T) = x in
       
   681             if member (op =) fpTs T then
       
   682               [x, build_map lthy (indexify_fst fpTs (K o nth fiters)) (T, mk_U T) $ x]
       
   683             else if exists_subtype_in fpTs T then
       
   684               [build_map lthy (indexify_fst fpTs (fn kk => fn _ =>
       
   685                  maybe_tick (nth us kk) (nth fiters kk))) (T, mk_nested_U maybe_mk_prodT T) $ x]
       
   686             else
       
   687               [x]
       
   688           end;
       
   689 
       
   690         fun tick u f = Term.lambda u (HOLogic.mk_prod (u, f $ u));
   677         fun tick u f = Term.lambda u (HOLogic.mk_prod (u, f $ u));
   691 
   678 
   692         val gxsss = map (map (flat_rec (single o List.last o unzip_iters gfolds (K I) (K I)))) xsss;
   679         fun unzip_iters fiters maybe_tick (x as Free (_, T)) Us =
   693         val hxsss = map (map (flat_rec (unzip_iters hrecs tick (curry HOLogic.mk_prodT)))) xsss;
   680           map (fn U => if U = T then x else
       
   681             build_map lthy (indexify_fst fpTs (fn kk => fn _ =>
       
   682               nth fiters kk |> length Us = 1 ? maybe_tick (nth us kk))) (T, U) $ x) Us;
       
   683 
       
   684         val gxsss = map2 (map2 (flat_rec oo map2 (unzip_iters gfolds (K I)))) xsss y_Tssss;
       
   685         val hxsss = map2 (map2 (flat_rec oo map2 (unzip_iters hrecs tick))) xsss z_Tssss;
   694 
   686 
   695         val fold_goalss = map5 (map4 o mk_goal gss) gfolds xctrss gss xsss gxsss;
   687         val fold_goalss = map5 (map4 o mk_goal gss) gfolds xctrss gss xsss gxsss;
   696         val rec_goalss = map5 (map4 o mk_goal hss) hrecs xctrss hss xsss hxsss;
   688         val rec_goalss = map5 (map4 o mk_goal hss) hrecs xctrss hss xsss hxsss;
   697 
   689 
   698         val fold_tacss =
   690         val fold_tacss =