src/HOL/Tools/BNF/bnf_lfp_size.ML
author blanchet
Wed Apr 23 10:23:26 2014 +0200 (2014-04-23)
changeset 56638 092a306bcc3d
child 56639 c9d6b581bd3b
permissions -rw-r--r--
generate size instances for new-style datatypes
     1 (*  Title:      HOL/Tools/BNF/bnf_lfp_size.ML
     2     Author:     Jasmin Blanchette, TU Muenchen
     3     Copyright   2014
     4 
     5 Generation of size functions for new-style datatypes.
     6 *)
     7 
     8 structure BNF_LFP_Size : sig end =
     9 struct
    10 
    11 open BNF_Util
    12 open BNF_Def
    13 open BNF_FP_Def_Sugar
    14 
    (* Prefix prepended to the base name of each datatype name to form the name of its
       generated size constant (see gen_size_names in generate_size). *)
    15 val size_N = "size_"
    16 
    (* Fact names under which generate_size notes its theorems, each qualified by the base
       name of the datatype: sizeN for the size simp rules, size_mapN for the size/map
       theorems. *)
    17 val sizeN = "size"
    18 val size_mapN = "size_map"
    19 
    (* Theory data recording the size information produced by generate_size.
       The table is keyed by type constructor name; each entry holds the name of the generated
       size constant together with two theorem lists (generate_size stores the generated size
       simp rules first and the size/map theorems second). *)
    20 structure Data = Theory_Data
    21 (
    22   type T = (string * (thm list * thm list)) Symtab.table;
    23   val empty = Symtab.empty;
    24   val extend = I
         (* NOTE(review): the equality argument (K true) treats any two entries as equal;
            presumably this keeps Symtab.merge from reporting a clash when both tables hold an
            entry for the same key -- confirm against Symtab.merge. *)
    25   fun merge data = Symtab.merge (K true) data;
    26 );
    27 
    (* The constant zero_class.zero at type nat, as a term. Used as the "contributes nothing"
       summand and as the body of constant-zero size functions. *)
    28 val zero_nat = @{const zero_class.zero (nat)};
    29 
    (* Builds the term "t1 + t2": the constant Groups.plus at type nat => nat => nat applied
       to t1 and then t2. The argument is a pair so that the function can be used directly
       with foldl1. *)
    30 fun mk_plus_nat (t1, t2) = Const (@{const_name Groups.plus},
    31   HOLogic.natT --> HOLogic.natT --> HOLogic.natT) $ t1 $ t2;
    32 
    (* The function type from T to nat, i.e. the type of a size (measure) function on T. *)
    33 fun mk_to_natT T = T --> HOLogic.natT;
    34 
    (* An abstraction over an unnamed bound variable of type T whose body is zero_nat, i.e. the
       size function on T that ignores its argument and returns zero. *)
    35 fun mk_abs_zero_nat T = Term.absdummy T zero_nat;
    36 
    (* Generates size functions and their theorems for a list of fp_sugars and returns the
       extended theory.

       Steps, in order (each works on the theory produced by the previous one):
         1. define one generic size constant per type (name: type base name prefixed with
            size_N), parameterized by one nat-valued function per type argument and expressed
            through the type's recursor (thy -> thy2);
         2. instantiate the size class for the types, defining size as the generic size constant
            applied to constant-zero functions (thy2 -> thy3);
         3. derive simp rules for both kinds of size function from the recursor theorems;
         4. when there are live type arguments and no nested size_map theorems were collected,
            prove the size/map theorems by induction;
         5. note the theorems and register Spec_Rules (thy3 -> thy4), then record the result
            in Data.

       Only a non-empty fp_sugars list whose first element has T and BT of the form Type is
       matched; any other argument raises Match. *)
    37 fun generate_size (fp_sugars as ({T = Type (_, As), BT = Type (_, Bs),
    38     fp_res = {bnfs = fp_bnfs, ...}, common_co_inducts = common_inducts, ...} : fp_sugar) :: _) thy =
    39   let
           (* Size information registered by earlier invocations. *)
    40     val data = Data.get thy;
    41 
           (* The types being processed, their constructor names, and how many there are. *)
    42     val Ts = map #T fp_sugars
    43     val T_names = map (fst o dest_Type) Ts;
    44     val nn = length Ts;
    45 
           (* Type substitution replacing each type argument in As by its counterpart in Bs. *)
    46     val B_ify = Term.typ_subst_atomic (As ~~ Bs);
    47 
           (* Recursors and their theorems. Cs are the result types of the recursors; substCT
              instantiates every C with nat, since the size functions return nat. *)
    48     val recs = map #co_rec fp_sugars;
    49     val rec_thmss = map #co_rec_thms fp_sugars;
    50     val rec_Ts = map fastype_of recs;
    51     val Cs = map body_type rec_Ts;
    52     val Cs_rho = map (rpair HOLogic.natT) Cs;
    53     val substCT = Term.subst_atomic_types Cs_rho;
    54 
           (* One free variable f<i> of type A_i => nat per type argument (fs), plus the same
              names at type B_i => nat (fsB). As_fs associates each A_i with its f<i>. *)
    55     val f_Ts = map mk_to_natT As;
    56     val f_TsB = map mk_to_natT Bs;
    57     val num_As = length As;
    58 
    59     val f_names = map (prefix "f" o string_of_int) (1 upto num_As);
    60     val fs = map2 (curry Free) f_names f_Ts;
    61     val fsB = map2 (curry Free) f_names f_TsB;
    62     val As_fs = As ~~ fs;
    63 
           (* Names of the generic size constants: each type name with size_N prepended to its
              base name. *)
    64     val gen_size_names = map (Long_Name.map_base_name (prefix size_N)) T_names;
    65 
           (* True for a prod type with exactly two arguments whose second one is among Cs. *)
    66     fun is_pair_C @{type_name prod} [_, T'] = member (op =) Cs T'
    67       | is_pair_C _ _ = false;
    68 
           (* Builds a term that measures values of type T. The result is returned in
              state-passing style (via pair): the threaded state is the list of size_map
              theorems of the registered size functions used so far.
                - type variable: its f<i> if it is one of As; the identity abstraction if it is
                  one of Cs (substCT, applied by mk_size_of_arg, later turns C into nat);
                  otherwise the constant-zero function;
                - prod type with a C as second component: snd_const T (presumably the second
                  projection at that pair type -- confirm in the opened BNF modules);
                - other type constructor with an argument for which exists_subtype_in As holds
                  (presumably: the argument mentions one of As): if a size constant is
                  registered in data for it, that constant applied to the size functions of its
                  type arguments, adding its size_map theorems and those of the arguments to
                  the state; if none is registered, the constant-zero function;
                - anything else: the constant-zero function. *)
    69     fun mk_size_of_typ (T as TFree _) =
    70         pair (case AList.lookup (op =) As_fs T of
    71             SOME f => f
    72           | NONE => if member (op =) Cs T then Term.absdummy T (Bound 0) else mk_abs_zero_nat T)
    73       | mk_size_of_typ (T as Type (s, Ts)) =
    74         if is_pair_C s Ts then
    75           pair (snd_const T)
    76         else if exists (exists_subtype_in As) Ts then
    77           (case Symtab.lookup data s of
    78             SOME (gen_size_name, (_, gen_size_maps)) =>
    79             let
                     (* Each argument is processed with an empty theorem list; the resulting
                        lists are merged into the threaded state below. *)
    80               val (args, gen_size_mapss') = split_list (map (fn T => mk_size_of_typ T []) Ts);
    81               val gen_size_const = Const (gen_size_name, map fastype_of args ---> mk_to_natT T);
    82             in
    83               fold (union Thm.eq_thm) (gen_size_maps :: gen_size_mapss')
    84               #> pair (Term.list_comb (gen_size_const, args))
    85             end
    86           | NONE => pair (mk_abs_zero_nat T))
    87         else
    88           pair (mk_abs_zero_nat T);
    89 
           (* The size of the term t: the size function for t's type applied to t
              (beta-reduced), with Cs instantiated to nat. *)
    90     fun mk_size_of_arg t =
    91       mk_size_of_typ (fastype_of t) #>> (fn s => substCT (betapply (s, t)));
    92 
           (* Builds one argument of the recursor for a recursor argument of type arg_T: a
              lambda over fresh variables x1..xm (one per binder type of arg_T) whose body is
              zero if no non-zero summand remains, and otherwise the sum of the non-zero
              argument sizes plus Suc 0. Threads the theorem list like mk_size_of_typ. *)
    93     fun mk_gen_size_arg arg_T gen_size_maps =
    94       let
    95         val x_Ts = binder_types arg_T;
    96         val m = length x_Ts;
    97         val x_names = map (prefix "x" o string_of_int) (1 upto m);
    98         val xs = map2 (curry Free) x_names x_Ts;
               (* Summands that are literally zero_nat are dropped. *)
    99         val (summands, gen_size_maps') =
   100           fold_map mk_size_of_arg xs gen_size_maps
   101           |>> remove (op =) zero_nat;
   102         val sum =
   103           if null summands then HOLogic.zero
   104           else foldl1 mk_plus_nat (summands @ [HOLogic.Suc_zero]);
   105       in
   106         (fold_rev Term.lambda (map substCT xs) sum, gen_size_maps')
   107       end;
   108 
           (* Right-hand side of a generic size definition: the recursor (with Cs instantiated
              to nat) applied to one mk_gen_size_arg term per element of binder_fun_types rec_T
              (presumably the function-typed arguments of the recursor -- confirm), abstracted
              over all fs. *)
   109     fun mk_gen_size_rhs rec_T recx gen_size_maps =
   110       let
   111         val arg_Ts = binder_fun_types rec_T;
   112         val (args, gen_size_maps') = fold_map mk_gen_size_arg arg_Ts gen_size_maps;
   113       in
   114         (fold_rev Term.lambda fs (Term.list_comb (substCT recx, args)), gen_size_maps')
   115       end;
   116 
           (* Concealed binding for a definition theorem: Thm.def_name of the base name of the
              given long name after transforming it with f. *)
   117     fun mk_def_binding f = Binding.conceal o Binding.name o Thm.def_name o f o Long_Name.base_name;
   118 
           (* Right-hand sides for all types, starting from an empty theorem list;
              nested_gen_size_maps are the size_map theorems collected on the way. *)
   119     val (gen_size_rhss, nested_gen_size_maps) = fold_map2 mk_gen_size_rhs rec_Ts recs [];
   120     val gen_size_Ts = map fastype_of gen_size_rhss;
   121     val gen_size_consts = map2 (curry Const) gen_size_names gen_size_Ts;
   122     val gen_size_constsB = map (Term.map_types B_ify) gen_size_consts;
   123     val gen_size_def_bs = map (mk_def_binding I) gen_size_names;
   124 
           (* Declare the generic size constants and add their definitions
              (constant == right-hand side). *)
   125     val (gen_size_defs, thy2) =
   126       thy
   127       |> Sign.add_consts (map (fn (s, T) => (Binding.name (Long_Name.base_name s), T, NoSyn))
   128         (gen_size_names ~~ gen_size_Ts))
   129       |> Global_Theory.add_defs false (map Thm.no_attributes (gen_size_def_bs ~~
   130         map Logic.mk_equals (gen_size_consts ~~ gen_size_rhss)));
   131 
           (* The class-based size function is the generic one with every f<i> instantiated by
              the constant-zero function; its definition theorem gets the suffix
              "_overloaded". *)
   132     val zeros = map mk_abs_zero_nat As;
   133 
   134     val spec_size_rhss = map (fn c => Term.list_comb (c, zeros)) gen_size_consts;
   135     val spec_size_Ts = map fastype_of spec_size_rhss;
   136     val spec_size_consts = map (curry Const @{const_name size}) spec_size_Ts;
   137     val spec_size_def_bs = map (mk_def_binding (suffix "_overloaded")) gen_size_names;
   138 
           (* Defines one instance of size inside the instantiation target: checks lhs0, defines
              it as rhs under binding def_b, registers the definition as an equational Spec_Rule
              and returns the definition theorem exported to the global context. *)
   139     fun define_spec_size def_b lhs0 rhs lthy =
   140       let
               (* The pattern assumes the checked term is a Free; any other shape raises Bind. *)
   141         val Free (c, _) = Syntax.check_term lthy lhs0;
   142         val (thm, lthy') = lthy
   143           |> Local_Theory.define ((Binding.name c, NoSyn), ((def_b, []), rhs))
   144           |-> (fn (t, (_, thm)) => Spec_Rules.add Spec_Rules.Equational ([t], [thm]) #> pair thm);
   145         val ctxt_thy = Proof_Context.init_global (Proof_Context.theory_of lthy');
   146         val thm' = singleton (Proof_Context.export lthy' ctxt_thy) thm;
   147       in (thm', lthy') end;
   148 
           (* Instantiate the size class for all T_names (type arguments taken from As), define
              the instances, prove the instance with intro_classes_tac and leave the local
              theory. *)
   149     val (spec_size_defs, thy3) = thy2
   150       |> Class.instantiation (T_names, map dest_TFree As, [HOLogic.class_size])
   151       |> fold_map3 define_spec_size spec_size_def_bs spec_size_consts spec_size_rhss
   152       ||> Class.prove_instantiation_instance (K (Class.intro_classes_tac []))
   153       ||> Local_Theory.exit_global;
   154 
   155     val thy3_ctxt = Proof_Context.init_global thy3;
   156 
           (* Definitions as object-level equations, processed by mk_unabs_def with num_As + 1
              resp. 1 arguments (presumably this applies both sides to that many arguments --
              confirm in the opened BNF modules). *)
   157     val gen_size_defs' =
   158       map (mk_unabs_def (num_As + 1) o (fn thm => thm RS meta_eq_to_obj_eq)) gen_size_defs;
   159     val spec_size_defs' =
   160       map (mk_unabs_def 1 o (fn thm => thm RS meta_eq_to_obj_eq)) spec_size_defs;
   161 
           (* Derives a size simp rule: chains size_def' with simp0 by transitivity, unfolds
              the given equations, then folds back the given definitions. *)
   162     fun derive_size_simp unfolds folds size_def' simp0 =
   163       fold_thms thy3_ctxt folds (unfold_thms thy3_ctxt unfolds (trans OF [size_def', simp0]));
   164     val derive_gen_size_simp =
   165       derive_size_simp (@{thm snd_o_convol} :: nested_gen_size_maps) gen_size_defs';
   166     val derive_spec_size_simp = derive_size_simp @{thms add_0_left add_0_right} spec_size_defs';
   167 
           (* Generic simp rules come from the recursor theorems; the class-based ones are
              derived in turn from the generic simp rules. *)
   168     val gen_size_simpss = map2 (map o derive_gen_size_simp) gen_size_defs' rec_thmss;
   169     val gen_size_simps = flat gen_size_simpss;
   170     val spec_size_simpss = map2 (map o derive_spec_size_simp) spec_size_defs' gen_size_simpss;
   171 
           (* Ingredients for the size/map theorems: one free g<i> of type A_i => B_i per type
              argument. An argument counts as live when A_i differs from B_i. *)
   172     val ABs = As ~~ Bs;
   173     val g_names = map (prefix "g" o string_of_int) (1 upto num_As);
   174     val gs = map2 (curry Free) g_names (map (op -->) ABs);
   175 
   176     val liveness = map (op <>) ABs;
   177     val live_gs = AList.find (op =) (gs ~~ liveness) true;
   178     val live = length live_gs;
   179 
           (* One free variable u<i> per type, standing for a value of that type. *)
   180     val u_names = map (prefix "u" o string_of_int) (1 upto nn);
   181     val us = map2 (curry Free) u_names Ts;
   182 
   183     val maps0 = map map_of_bnf fp_bnfs;
   184     val map_thms = maps #maps fp_sugars;
   185 
           (* Tactic for the size/map goal: apply the rule selected by co_induct_of from
              common_inducts to the first goal, then simplify all goals using only o_apply,
              the map theorems and the generic size simp rules. *)
   186     fun mk_gen_size_map_tac ctxt =
   187       HEADGOAL (rtac (co_induct_of common_inducts)) THEN
   188       ALLGOALS (asm_simp_tac (ss_only (o_apply :: map_thms @ gen_size_simps) ctxt));
   189 
           (* Size/map theorems, one list per type. Nothing is produced when there is no live
              argument or when nested size_map theorems were collected (see the TODO below). *)
   190     val gen_size_map_thmss =
   191       if live = 0 then
   192         replicate nn []
   193       else if null nested_gen_size_maps then
   194         let
                 (* Left-hand sides: size function over fsB applied to the mapped value
                    "map live_gs u". *)
   195           val xgmaps =
   196             map2 (fn map0 => fn u => Term.list_comb (mk_map live As Bs map0, live_gs) $ u) maps0 us;
   197           val fsizes =
   198             map (fn gen_size_constB => Term.list_comb (gen_size_constB, fsB)) gen_size_constsB;
   199           val lhss = map2 (curry (op $)) fsizes xgmaps;
   200 
                 (* Right-hand sides: size function applied to u, where each measure function
                    is fB composed with g for a live argument (A differs from B) and fB itself
                    otherwise. *)
   201           val fgs = map2 (fn fB => fn g as Free (_, Type (_, [A, B])) =>
   202             if A = B then fB else HOLogic.mk_comp (fB, g)) fsB gs;
   203           val rhss = map2 (fn gen_size_const => fn u => Term.list_comb (gen_size_const, fgs) $ u)
   204             gen_size_consts us;
   205 
                 (* A single goal: the conjunction of all equations lhs = rhs. *)
   206           val goal = Library.foldr1 HOLogic.mk_conj (map2 (curry HOLogic.mk_eq) lhss rhss)
   207             |> HOLogic.mk_Trueprop;
   208         in
                 (* Prove the conjunction, then split it with conj_dests nn (presumably into
                    one theorem per type) and wrap each theorem in a singleton list. *)
   209           Goal.prove_global thy3 [] [] goal (mk_gen_size_map_tac o #context)
   210           |> Thm.close_derivation
   211           |> conj_dests nn
   212           |> map single
   213         end
   214       else
   215         (* TODO: implement general case, with nesting of datatypes that themselves nest other
   216            types *)
   217         replicate nn [];
   218 
           (* Note the theorems per type, qualified by the type's base name: the generic and
              class-based simp rules together under sizeN (declared as simp rules, Nitpick
              simps and default code equations) and the size/map theorems under size_mapN.
              Facts whose theorem lists are all empty are dropped. Finally the generic size
              constants and their simp rules are registered as equational Spec_Rules. *)
   219     val (_, thy4) = thy3
   220       |> fold_map3 (fn T_name => fn size_simps => fn gen_size_map_thms =>
   221           let val qualify = Binding.qualify true (Long_Name.base_name T_name) in
   222             Global_Theory.note_thmss ""
   223               ([((qualify (Binding.name sizeN),
   224                    [Simplifier.simp_add, Nitpick_Simps.add, Thm.declaration_attribute
   225                       (fn thm => Context.mapping (Code.add_default_eqn thm) I)]),
   226                  [(size_simps, [])]),
   227                 ((qualify (Binding.name size_mapN), []), [(gen_size_map_thms, [])])]
   228                |> filter_out (forall (null o fst) o snd))
   229           end)
   230         T_names (map2 append gen_size_simpss spec_size_simpss) gen_size_map_thmss
   231       ||> Spec_Rules.add_global Spec_Rules.Equational (gen_size_consts, gen_size_simps);
   232   in
         (* Record, for every type name, its generic size constant name together with all
            generic simp rules and all size/map theorems of this call. The entries are added
            with Symtab.update_new. *)
   233     thy4
   234     |> Data.map (fold2 (fn T_name => fn gen_size_name =>
   235         Symtab.update_new (T_name, (gen_size_name, (gen_size_simps, flat gen_size_map_thmss))))
   236       T_names gen_size_names)
   237   end;
   238 
   (* Registers generate_size through fp_sugar_interpretation at theory setup time; presumably
      the interpretation then invokes it for each list of fp_sugars -- confirm in
      BNF_FP_Def_Sugar. *)
   239 val _ = Theory.setup (fp_sugar_interpretation generate_size);
   240 
   241 end;