src/HOL/Codatatype/Tools/bnf_fp_sugar.ML
changeset 49124 968e1b7de057
parent 49123 263b0e330d8b
child 49125 5fc5211cf104
equal deleted inserted replaced
49123:263b0e330d8b 49124:968e1b7de057
    17 open BNF_FP_Util
    17 open BNF_FP_Util
    18 open BNF_LFP
    18 open BNF_LFP
    19 open BNF_GFP
    19 open BNF_GFP
    20 open BNF_FP_Sugar_Tactics
    20 open BNF_FP_Sugar_Tactics
    21 
    21 
    22 fun cannot_merge_types () = error "Mutually recursive (co)datatypes must have same type parameters";
    22 fun cannot_merge_types () = error "Mutually recursive types must have the same type parameters";
    23 
    23 
    24 fun merge_type_arg_constrained ctxt (T, c) (T', c') =
    24 fun merge_type_arg_constrained ctxt (T, c) (T', c') =
    25   if T = T' then
    25   if T = T' then
    26     (case (c, c') of
    26     (case (c, c') of
    27       (_, NONE) => (T, c)
    27       (_, NONE) => (T, c)
   105     val ctr_TsssBs = map (map (map freeze_rec)) ctr_Tsss;
   105     val ctr_TsssBs = map (map (map freeze_rec)) ctr_Tsss;
   106     val sum_prod_TsBs = map (mk_sumTN o map HOLogic.mk_tupleT) ctr_TsssBs;
   106     val sum_prod_TsBs = map (mk_sumTN o map HOLogic.mk_tupleT) ctr_TsssBs;
   107 
   107 
   108     val eqs = map dest_TFree Bs ~~ sum_prod_TsBs;
   108     val eqs = map dest_TFree Bs ~~ sum_prod_TsBs;
   109 
   109 
   110     val (raw_flds, lthy') = fp_bnf construct bs eqs lthy;
   110     val ((raw_flds, raw_unfs, fld_unfs, unf_flds), lthy') = fp_bnf construct bs eqs lthy;
   111 
   111 
   112     fun mk_fld Ts fld =
   112     fun mk_fld_or_unf get_foldedT Ts t =
   113       let val Type (_, Ts0) = body_type (fastype_of fld) in
   113       let val Type (_, Ts0) = get_foldedT (fastype_of t) in
   114         Term.subst_atomic_types (Ts0 ~~ Ts) fld
   114         Term.subst_atomic_types (Ts0 ~~ Ts) t
   115       end;
   115       end;
   116 
   116 
       
   117     val mk_fld = mk_fld_or_unf range_type;
       
   118     val mk_unf = mk_fld_or_unf domain_type;
       
   119 
   117     val flds = map (mk_fld As) raw_flds;
   120     val flds = map (mk_fld As) raw_flds;
   118 
   121     val unfs = map (mk_unf As) raw_unfs;
   119     fun wrap_type (((((T, fld), ctr_names), ctr_Tss), disc_names), sel_namess) no_defs_lthy =
   122 
       
   123     fun wrap_type ((((((((T, fld), unf), fld_unf), unf_fld), ctr_names), ctr_Tss), disc_names),
       
   124         sel_namess) no_defs_lthy =
   120       let
   125       let
   121         val n = length ctr_names;
   126         val n = length ctr_names;
   122         val ks = 1 upto n;
   127         val ks = 1 upto n;
   123         val ms = map length ctr_Tss;
   128         val ms = map length ctr_Tss;
   124 
   129 
       
   130         val unf_T = domain_type (fastype_of fld);
       
   131 
   125         val prod_Ts = map HOLogic.mk_tupleT ctr_Tss;
   132         val prod_Ts = map HOLogic.mk_tupleT ctr_Tss;
   126 
   133 
   127         val (xss, _) = lthy |> mk_Freess "x" ctr_Tss;
   134         val (((u, v), xss), _) =
       
   135           lthy
       
   136           |> yield_singleton (mk_Frees "u") unf_T
       
   137           ||>> yield_singleton (mk_Frees "v") T
       
   138           ||>> mk_Freess "x" ctr_Tss;
   128 
   139 
   129         val rhss =
   140         val rhss =
   130           map2 (fn k => fn xs =>
   141           map2 (fn k => fn xs =>
   131             fold_rev Term.lambda xs (fld $ mk_InN prod_Ts (HOLogic.mk_tuple xs) k)) ks xss;
   142             fold_rev Term.lambda xs (fld $ mk_InN prod_Ts (HOLogic.mk_tuple xs) k)) ks xss;
   132 
   143 
   146         val ctrs = map (Morphism.term phi) raw_ctrs;
   157         val ctrs = map (Morphism.term phi) raw_ctrs;
   147         val caseof = Morphism.term phi raw_caseof;
   158         val caseof = Morphism.term phi raw_caseof;
   148 
   159 
   149         val fld_iff_unf_thm =
   160         val fld_iff_unf_thm =
   150           let
   161           let
   151             val fld = @{term "undefined::'a=>'b"};
   162             val goal =
   152             val unf = @{term True};
   163               fold_rev Logic.all [u, v]
   153             val (T, T') = dest_funT (fastype_of fld);
   164                 (mk_Trueprop_eq (HOLogic.mk_eq (v, fld $ u), HOLogic.mk_eq (unf $ v, u)));
   154             val fld_unf = TrueI;
       
   155             val unf_fld = TrueI;
       
   156             val goal = @{term True};
       
   157           in
   165           in
   158             Skip_Proof.prove lthy [] [] goal (fn {context = ctxt, ...} =>
   166             Skip_Proof.prove lthy [] [] goal (fn {context = ctxt, ...} =>
   159               mk_fld_iff_unf_tac ctxt (map (SOME o certifyT lthy) [T, T']) (certify lthy fld)
   167               mk_fld_iff_unf_tac ctxt (map (SOME o certifyT lthy) [unf_T, T]) (certify lthy fld)
   160                 (certify lthy unf) fld_unf unf_fld)
   168                 (certify lthy unf) fld_unf unf_fld)
   161           end;
   169           end;
   162 
   170 
   163         (* ### *)
   171         (* ### *)
   164         fun cheat_tac {context = ctxt, ...} = Skip_Proof.cheat_tac (Proof_Context.theory_of ctxt);
   172         fun cheat_tac {context = ctxt, ...} = Skip_Proof.cheat_tac (Proof_Context.theory_of ctxt);
   174         val tacss = [exhaust_tac] :: inject_tacss @ half_distinct_tacss @ [case_tacs];
   182         val tacss = [exhaust_tac] :: inject_tacss @ half_distinct_tacss @ [case_tacs];
   175       in
   183       in
   176         wrap_data tacss ((ctrs, caseof), (disc_names, sel_namess)) lthy'
   184         wrap_data tacss ((ctrs, caseof), (disc_names, sel_namess)) lthy'
   177       end;
   185       end;
   178   in
   186   in
   179     lthy' |> fold wrap_type (Ts ~~ flds ~~ ctr_namess ~~ ctr_Tsss ~~ disc_namess ~~ sel_namesss)
   187     lthy'
       
   188     |> fold wrap_type (Ts ~~ flds ~~ unfs ~~ fld_unfs ~~ unf_flds ~~ ctr_namess ~~ ctr_Tsss ~~
       
   189       disc_namess ~~ sel_namesss)
   180   end;
   190   end;
   181 
   191 
   182 fun data_cmd info specs lthy =
   192 fun data_cmd info specs lthy =
   183   let
   193   let
   184     val fake_lthy =
   194     val fake_lthy =