src/HOL/Tools/BNF/bnf_lfp_size.ML
author blanchet
Wed Apr 23 10:23:27 2014 +0200 (2014-04-23)
changeset 56654 54326fa7afe6
parent 56651 fc105315822a
child 56682 d39926ff0487
permissions -rw-r--r--
qualify name
     1 (*  Title:      HOL/Tools/BNF/bnf_lfp_size.ML
     2     Author:     Jasmin Blanchette, TU Muenchen
     3     Copyright   2014
     4 
     5 Generation of size functions for new-style datatypes.
     6 *)
     7 
     8 signature BNF_LFP_SIZE =
     9 sig
    10   val register_size: string -> string -> thm list -> thm list -> local_theory -> local_theory
    11   val register_size_global: string -> string -> thm list -> thm list -> theory -> theory
    12   val lookup_size: Proof.context -> string -> (string * (thm list * thm list)) option
    13   val lookup_size_global: theory -> string -> (string * (thm list * thm list)) option
    14   val generate_lfp_size: BNF_FP_Util.fp_sugar list -> local_theory -> local_theory
    15 end;
    16 
    17 structure BNF_LFP_Size : BNF_LFP_SIZE =
    18 struct
    19 
    20 open BNF_Util
    21 open BNF_Tactics
    22 open BNF_Def
    23 open BNF_FP_Util
    24 
    25 val size_N = "size_"
    26 
    27 val rec_o_mapN = "rec_o_map"
    28 val sizeN = "size"
    29 val size_o_mapN = "size_o_map"
    30 
    31 val nitpicksimp_attrs = @{attributes [nitpick_simp]};
    32 val simp_attrs = @{attributes [simp]};
    33 val code_nitpicksimp_simp_attrs = Code.add_default_eqn_attrib :: nitpicksimp_attrs @ simp_attrs;
    34 
    35 structure Data = Generic_Data
    36 (
    37   type T = (string * (thm list * thm list)) Symtab.table;
    38   val empty = Symtab.empty;
    39   val extend = I
    40   fun merge data = Symtab.merge (K true) data;
    41 );
    42 
    43 fun register_size T_name size_name size_simps size_o_maps =
    44   Context.proof_map (Data.map (Symtab.update (T_name, (size_name, (size_simps, size_o_maps)))));
    45 
    46 fun register_size_global T_name size_name size_simps size_o_maps =
    47   Context.theory_map (Data.map (Symtab.update (T_name, (size_name, (size_simps, size_o_maps)))));
    48 
    49 val lookup_size = Symtab.lookup o Data.get o Context.Proof;
    50 val lookup_size_global = Symtab.lookup o Data.get o Context.Theory;
    51 
    52 val zero_nat = @{const zero_class.zero (nat)};
    53 
    54 fun mk_plus_nat (t1, t2) = Const (@{const_name Groups.plus},
    55   HOLogic.natT --> HOLogic.natT --> HOLogic.natT) $ t1 $ t2;
    56 
    57 fun mk_to_natT T = T --> HOLogic.natT;
    58 
    59 fun mk_abs_zero_nat T = Term.absdummy T zero_nat;
    60 
    61 fun pointfill ctxt th = unfold_thms ctxt [o_apply] (th RS fun_cong);
    62 
    63 fun mk_unabs_def_unused_0 n =
    64   funpow n (fn thm => thm RS @{thm fun_cong_unused_0} handle THM _ => thm RS fun_cong);
    65 
    66 val rec_o_map_simp_thms =
    67   @{thms o_def id_def case_prod_app case_sum_map_sum case_prod_map_prod BNF_Comp.id_bnf_comp_def};
    68 
    69 fun mk_rec_o_map_tac ctxt rec_def pre_map_defs abs_inverses ctor_rec_o_map =
    70   unfold_thms_tac ctxt [rec_def] THEN
    71   HEADGOAL (rtac (ctor_rec_o_map RS trans)) THEN
    72   PRIMITIVE (Conv.fconv_rule Thm.eta_long_conversion) THEN
    73   HEADGOAL (asm_simp_tac (ss_only (pre_map_defs @ distinct Thm.eq_thm_prop abs_inverses @
    74     rec_o_map_simp_thms) ctxt));
    75 
    76 val size_o_map_simp_thms = @{thms prod_inj_map inj_on_id snd_comp_apfst[unfolded apfst_def]};
    77 
    78 fun mk_size_o_map_tac ctxt size_def rec_o_map inj_maps size_maps =
    79   unfold_thms_tac ctxt [size_def] THEN
    80   HEADGOAL (rtac (rec_o_map RS trans) THEN'
    81     asm_simp_tac (ss_only (inj_maps @ size_maps @ size_o_map_simp_thms) ctxt)) THEN
    82   IF_UNSOLVED (unfold_thms_tac ctxt @{thms o_def} THEN HEADGOAL (rtac refl));
    83 
    84 fun generate_lfp_size (fp_sugars as ({T = Type (_, As), BT = Type (_, Bs),
    85     fp_res = {bnfs = fp_bnfs, xtor_co_rec_o_map_thms = ctor_rec_o_maps, ...}, nested_bnfs,
    86     nesting_bnfs, ...} : fp_sugar) :: _) lthy0 =
    87   let
    88     val data = Data.get (Context.Proof lthy0);
    89 
    90     val Ts = map #T fp_sugars
    91     val T_names = map (fst o dest_Type) Ts;
    92     val nn = length Ts;
    93 
    94     val B_ify = Term.typ_subst_atomic (As ~~ Bs);
    95 
    96     val recs = map #co_rec fp_sugars;
    97     val rec_thmss = map #co_rec_thms fp_sugars;
    98     val rec_Ts as rec_T1 :: _ = map fastype_of recs;
    99     val rec_arg_Ts = binder_fun_types rec_T1;
   100     val Cs = map body_type rec_Ts;
   101     val Cs_rho = map (rpair HOLogic.natT) Cs;
   102     val substCnatT = Term.subst_atomic_types Cs_rho;
   103 
   104     val f_Ts = map mk_to_natT As;
   105     val f_TsB = map mk_to_natT Bs;
   106     val num_As = length As;
   107 
   108     fun variant_names n pre = fst (Variable.variant_fixes (replicate n pre) lthy0);
   109 
   110     val f_names = variant_names num_As "f";
   111     val fs = map2 (curry Free) f_names f_Ts;
   112     val fsB = map2 (curry Free) f_names f_TsB;
   113     val As_fs = As ~~ fs;
   114 
   115     val size_bs =
   116       map ((fn base => Binding.qualify false base (Binding.name (prefix size_N base))) o
   117         Long_Name.base_name) T_names;
   118 
   119     fun is_pair_C @{type_name prod} [_, T'] = member (op =) Cs T'
   120       | is_pair_C _ _ = false;
   121 
   122     fun mk_size_of_typ (T as TFree _) =
   123         pair (case AList.lookup (op =) As_fs T of
   124             SOME f => f
   125           | NONE => if member (op =) Cs T then Term.absdummy T (Bound 0) else mk_abs_zero_nat T)
   126       | mk_size_of_typ (T as Type (s, Ts)) =
   127         if is_pair_C s Ts then
   128           pair (snd_const T)
   129         else if exists (exists_subtype_in As) Ts then
   130           (case Symtab.lookup data s of
   131             SOME (size_name, (_, size_o_maps as _ :: _)) =>
   132             let
   133               val (args, size_o_mapss') = split_list (map (fn T => mk_size_of_typ T []) Ts);
   134               val size_const = Const (size_name, map fastype_of args ---> mk_to_natT T);
   135             in
   136               fold (union Thm.eq_thm) (size_o_maps :: size_o_mapss')
   137               #> pair (Term.list_comb (size_const, args))
   138             end
   139           | _ => pair (mk_abs_zero_nat T))
   140         else
   141           pair (mk_abs_zero_nat T);
   142 
   143     fun mk_size_of_arg t =
   144       mk_size_of_typ (fastype_of t) #>> (fn s => substCnatT (betapply (s, t)));
   145 
   146     fun mk_size_arg rec_arg_T size_o_maps =
   147       let
   148         val x_Ts = binder_types rec_arg_T;
   149         val m = length x_Ts;
   150         val x_names = variant_names m "x";
   151         val xs = map2 (curry Free) x_names x_Ts;
   152         val (summands, size_o_maps') =
   153           fold_map mk_size_of_arg xs size_o_maps
   154           |>> remove (op =) zero_nat;
   155         val sum =
   156           if null summands then HOLogic.zero
   157           else foldl1 mk_plus_nat (summands @ [HOLogic.Suc_zero]);
   158       in
   159         (fold_rev Term.lambda (map substCnatT xs) sum, size_o_maps')
   160       end;
   161 
   162     fun mk_size_rhs recx size_o_maps =
   163       let val (args, size_o_maps') = fold_map mk_size_arg rec_arg_Ts size_o_maps in
   164         (fold_rev Term.lambda fs (Term.list_comb (substCnatT recx, args)), size_o_maps')
   165       end;
   166 
   167     val maybe_conceal_def_binding = Thm.def_binding
   168       #> Config.get lthy0 bnf_note_all = false ? Binding.conceal;
   169 
   170     val (size_rhss, nested_size_o_maps) = fold_map mk_size_rhs recs [];
   171     val size_Ts = map fastype_of size_rhss;
   172 
   173     val ((raw_size_consts, raw_size_defs), (lthy1', lthy1)) = lthy0
   174       |> apfst split_list o fold_map2 (fn b => fn rhs =>
   175           Local_Theory.define ((b, NoSyn), ((maybe_conceal_def_binding b, []), rhs)) #>> apsnd snd)
   176         size_bs size_rhss
   177       ||> `Local_Theory.restore;
   178 
   179     val phi = Proof_Context.export_morphism lthy1 lthy1';
   180 
   181     val size_defs = map (Morphism.thm phi) raw_size_defs;
   182 
   183     val size_consts0 = map (Morphism.term phi) raw_size_consts;
   184     val size_consts = map2 retype_const_or_free size_Ts size_consts0;
   185     val size_constsB = map (Term.map_types B_ify) size_consts;
   186 
   187     val zeros = map mk_abs_zero_nat As;
   188 
   189     val overloaded_size_rhss = map (fn c => Term.list_comb (c, zeros)) size_consts;
   190     val overloaded_size_Ts = map fastype_of overloaded_size_rhss;
   191     val overloaded_size_consts = map (curry Const @{const_name size}) overloaded_size_Ts;
   192     val overloaded_size_def_bs =
   193       map (maybe_conceal_def_binding o Binding.suffix_name "_overloaded") size_bs;
   194 
   195     fun define_overloaded_size def_b lhs0 rhs lthy =
   196       let
   197         val Free (c, _) = Syntax.check_term lthy lhs0;
   198         val (thm, lthy') = lthy
   199           |> Local_Theory.define ((Binding.name c, NoSyn), ((def_b, []), rhs))
   200           |-> (fn (t, (_, thm)) => Spec_Rules.add Spec_Rules.Equational ([t], [thm]) #> pair thm);
   201         val ctxt_thy = Proof_Context.init_global (Proof_Context.theory_of lthy');
   202         val thm' = singleton (Proof_Context.export lthy' ctxt_thy) thm;
   203       in (thm', lthy') end;
   204 
   205     val (overloaded_size_defs, lthy2) = lthy1
   206       |> Local_Theory.background_theory_result
   207         (Class.instantiation (T_names, map dest_TFree As, [HOLogic.class_size])
   208          #> fold_map3 define_overloaded_size overloaded_size_def_bs overloaded_size_consts
   209            overloaded_size_rhss
   210          ##> Class.prove_instantiation_instance (K (Class.intro_classes_tac []))
   211          ##> Local_Theory.exit_global);
   212 
   213     val size_defs' =
   214       map (mk_unabs_def (num_As + 1) o (fn thm => thm RS meta_eq_to_obj_eq)) size_defs;
   215     val size_defs_unused_0 =
   216       map (mk_unabs_def_unused_0 (num_As + 1) o (fn thm => thm RS meta_eq_to_obj_eq)) size_defs;
   217     val overloaded_size_defs' =
   218       map (mk_unabs_def 1 o (fn thm => thm RS meta_eq_to_obj_eq)) overloaded_size_defs;
   219 
   220     val nested_size_maps = map (pointfill lthy2) nested_size_o_maps @ nested_size_o_maps;
   221     val all_inj_maps = map inj_map_of_bnf (fp_bnfs @ nested_bnfs @ nesting_bnfs);
   222 
   223     fun derive_size_simp size_def' simp0 =
   224       (trans OF [size_def', simp0])
   225       |> Simplifier.asm_full_simplify (ss_only (@{thms inj_on_convol_id snd_o_convol} @
   226         all_inj_maps @ nested_size_maps) lthy2)
   227       |> fold_thms lthy2 size_defs_unused_0;
   228     fun derive_overloaded_size_simp size_def' simp0 =
   229       (trans OF [size_def', simp0])
   230       |> unfold_thms lthy2 @{thms add_0_left add_0_right}
   231       |> fold_thms lthy2 overloaded_size_defs';
   232 
   233     val size_simpss = map2 (map o derive_size_simp) size_defs' rec_thmss;
   234     val size_simps = flat size_simpss;
   235     val overloaded_size_simpss =
   236       map2 (map o derive_overloaded_size_simp) overloaded_size_defs' size_simpss;
   237     val size_thmss = map2 append size_simpss overloaded_size_simpss;
   238 
   239     val ABs = As ~~ Bs;
   240     val g_names = variant_names num_As "g";
   241     val gs = map2 (curry Free) g_names (map (op -->) ABs);
   242 
   243     val liveness = map (op <>) ABs;
   244     val live_gs = AList.find (op =) (gs ~~ liveness) true;
   245     val live = length live_gs;
   246 
   247     val maps0 = map map_of_bnf fp_bnfs;
   248 
   249     val (rec_o_map_thmss, size_o_map_thmss) =
   250       if live = 0 then
   251         `I (replicate nn [])
   252       else
   253         let
   254           val pre_bnfs = map #pre_bnf fp_sugars;
   255           val pre_map_defs = map map_def_of_bnf pre_bnfs;
   256           val abs_inverses = map (#abs_inverse o #absT_info) fp_sugars;
   257           val rec_defs = map #co_rec_def fp_sugars;
   258 
   259           val gmaps = map (fn map0 => Term.list_comb (mk_map live As Bs map0, live_gs)) maps0;
   260 
   261           val num_rec_args = length rec_arg_Ts;
   262           val h_Ts = map B_ify rec_arg_Ts;
   263           val h_names = variant_names num_rec_args "h";
   264           val hs = map2 (curry Free) h_names h_Ts;
   265           val hrecs = map (fn recx => Term.list_comb (Term.map_types B_ify recx, hs)) recs;
   266 
   267           val rec_o_map_lhss = map2 (curry HOLogic.mk_comp) hrecs gmaps;
   268 
   269           val ABgs = ABs ~~ gs;
   270 
   271           fun mk_rec_arg_arg (x as Free (_, T)) =
   272             let val U = B_ify T in
   273               if T = U then x else build_map lthy2 (the o AList.lookup (op =) ABgs) (T, U) $ x
   274             end;
   275 
   276           fun mk_rec_o_map_arg rec_arg_T h =
   277             let
   278               val x_Ts = binder_types rec_arg_T;
   279               val m = length x_Ts;
   280               val x_names = variant_names m "x";
   281               val xs = map2 (curry Free) x_names x_Ts;
   282               val xs' = map mk_rec_arg_arg xs;
   283             in
   284               fold_rev Term.lambda xs (Term.list_comb (h, xs'))
   285             end;
   286 
   287           fun mk_rec_o_map_rhs recx =
   288             let val args = map2 mk_rec_o_map_arg rec_arg_Ts hs in
   289               Term.list_comb (recx, args)
   290             end;
   291 
   292           val rec_o_map_rhss = map mk_rec_o_map_rhs recs;
   293 
   294           val rec_o_map_goals =
   295             map2 (fold_rev (fold_rev Logic.all) [gs, hs] o HOLogic.mk_Trueprop oo
   296               curry HOLogic.mk_eq) rec_o_map_lhss rec_o_map_rhss;
   297           val rec_o_map_thms =
   298             map3 (fn goal => fn rec_def => fn ctor_rec_o_map =>
   299                 Goal.prove lthy2 [] [] goal (fn {context = ctxt, ...} =>
   300                   mk_rec_o_map_tac ctxt rec_def pre_map_defs abs_inverses ctor_rec_o_map)
   301                 |> Thm.close_derivation)
   302               rec_o_map_goals rec_defs ctor_rec_o_maps;
   303 
   304           val size_o_map_conds =
   305             if exists (can Logic.dest_implies o Thm.prop_of) nested_size_o_maps then
   306               map (HOLogic.mk_Trueprop o mk_inj) live_gs
   307             else
   308               [];
   309 
   310           val fsizes = map (fn size_constB => Term.list_comb (size_constB, fsB)) size_constsB;
   311           val size_o_map_lhss = map2 (curry HOLogic.mk_comp) fsizes gmaps;
   312 
   313           val fgs = map2 (fn fB => fn g as Free (_, Type (_, [A, B])) =>
   314             if A = B then fB else HOLogic.mk_comp (fB, g)) fsB gs;
   315           val size_o_map_rhss = map (fn c => Term.list_comb (c, fgs)) size_consts;
   316 
   317           val size_o_map_goals =
   318             map2 (fold_rev (fold_rev Logic.all) [fsB, gs] o
   319               curry Logic.list_implies size_o_map_conds o HOLogic.mk_Trueprop oo
   320               curry HOLogic.mk_eq) size_o_map_lhss size_o_map_rhss;
   321           val size_o_map_thms =
   322             map3 (fn goal => fn size_def => fn rec_o_map =>
   323                 Goal.prove lthy2 [] [] goal (fn {context = ctxt, ...} =>
   324                   mk_size_o_map_tac ctxt size_def rec_o_map all_inj_maps nested_size_maps)
   325                 |> Thm.close_derivation)
   326               size_o_map_goals size_defs rec_o_map_thms;
   327         in
   328           pairself (map single) (rec_o_map_thms, size_o_map_thms)
   329         end;
   330 
   331     val massage_multi_notes =
   332       maps (fn (thmN, thmss, attrs) =>
   333         map2 (fn T_name => fn thms =>
   334             ((Binding.qualify true (Long_Name.base_name T_name) (Binding.name thmN), attrs),
   335              [(thms, [])]))
   336           T_names thmss)
   337       #> filter_out (null o fst o hd o snd);
   338 
   339     val notes =
   340       [(rec_o_mapN, rec_o_map_thmss, []),
   341        (sizeN, size_thmss, code_nitpicksimp_simp_attrs),
   342        (size_o_mapN, size_o_map_thmss, [])]
   343       |> massage_multi_notes;
   344   in
   345     lthy2
   346     |> Local_Theory.notes notes |> snd
   347     |> Spec_Rules.add Spec_Rules.Equational (size_consts, size_simps)
   348     |> Local_Theory.declaration {syntax = false, pervasive = true}
   349       (fn phi => Data.map (fold2 (fn T_name => fn Const (size_name, _) =>
   350            Symtab.update (T_name, (size_name,
   351              pairself (map (Morphism.thm phi)) (size_simps, flat size_o_map_thmss))))
   352          T_names size_consts))
   353   end;
   354 
   355 end;