src/HOL/Tools/BNF/bnf_lfp_size.ML
author desharna
Mon Oct 06 13:37:38 2014 +0200 (2014-10-06)
changeset 58578 9ff8ca957c02
parent 58461 75ee8d49c724
child 58634 9f10d82e8188
permissions -rw-r--r--
rename 'xtor_co_rec_o_map_thms' to 'xtor_co_rec_o_maps'
     1 (*  Title:      HOL/Tools/BNF/bnf_lfp_size.ML
     2     Author:     Jasmin Blanchette, TU Muenchen
     3     Copyright   2014
     4 
     5 Generation of size functions for datatypes.
     6 *)
     7 
     signature BNF_LFP_SIZE =
     sig
       (* Name of the plugin under which the size generator is installed. *)
       val size_plugin: string

       (* Local-theory / proof-context interface. The registered information for a type name is
          (name of its size constant, (size simp rules, size-o-map rules)). *)
       val register_size: string -> string -> thm list -> thm list -> local_theory -> local_theory
       val size_of: Proof.context -> string -> (string * (thm list * thm list)) option

       (* The same operations on a background theory. *)
       val register_size_global: string -> string -> thm list -> thm list -> theory -> theory
       val size_of_global: theory -> string -> (string * (thm list * thm list)) option
     end;
    16 
    17 structure BNF_LFP_Size : BNF_LFP_SIZE =
    18 struct
    19 
    20 open BNF_Util
    21 open BNF_Tactics
    22 open BNF_Def
    23 open BNF_FP_Def_Sugar
    24 
    (* Plugin name; passed to fp_sugars_interpretation when the generator is set up. *)
    val size_plugin = "size";

    (* Prefixed to a datatype's base name to form the binding of its generic size constant. *)
    val size_N = "size_";

    (* Base names of the theorem collections noted for each datatype. *)
    val sizeN = "size";
    val rec_o_mapN = "rec_o_map";
    val size_o_mapN = "size_o_map";

    (* Attribute lists attached to the noted "size" theorems. *)
    val simp_attrs = @{attributes [simp]};
    val nitpicksimp_attrs = @{attributes [nitpick_simp]};
    35 
    (* Generic data (usable from theories and proof contexts) keyed by a datatype's type name.
       Each entry is (name of the size constant, (size simp rules, size-o-map rules)). *)
    structure Data = Generic_Data
    (
      type T = (string * (thm list * thm list)) Symtab.table;
      val empty = Symtab.empty;
      val extend = I;
      (* K true treats any two entries for the same key as equal, so merging never fails on a
         duplicate key. *)
      fun merge tables = Symtab.merge (K true) tables;
    );
    43 
    (* Data transformer that installs (or overwrites) the entry for T_name. *)
    fun update_size_data T_name size_name size_simps size_o_maps =
      Data.map (Symtab.update (T_name, (size_name, (size_simps, size_o_maps))));

    (* Register size information for T_name in a local theory. *)
    fun register_size T_name size_name size_simps size_o_maps =
      Context.proof_map (update_size_data T_name size_name size_simps size_o_maps);

    (* Register size information for T_name in a background theory. *)
    fun register_size_global T_name size_name size_simps size_o_maps =
      Context.theory_map (update_size_data T_name size_name size_simps size_o_maps);

    (* Look up the entry registered for a type name, if any. *)
    fun size_of ctxt = Symtab.lookup (Data.get (Context.Proof ctxt));
    fun size_of_global thy = Symtab.lookup (Data.get (Context.Theory thy));
    52 
    (* The function type "T => nat". *)
    fun mk_to_natT T = T --> HOLogic.natT;

    (* The term "t1 + t2", with Groups.plus taken at type nat => nat => nat. *)
    fun mk_plus_nat (t1, t2) =
      let
        val plus_nat =
          Const (@{const_name Groups.plus}, [HOLogic.natT, HOLogic.natT] ---> HOLogic.natT);
      in
        plus_nat $ t1 $ t2
      end;

    (* The constant function "%_ :: T. HOLogic.zero". *)
    fun mk_abs_zero_nat T = HOLogic.zero |> Term.absdummy T;
    59 
    (* Turn an equation between functions into its applied form via fun_cong, then unfold o_apply
       so that applied compositions are rewritten. *)
    fun mk_pointfull ctxt th =
      (th RS fun_cong) |> unfold_thms ctxt [o_apply];

    (* Apply n arguments to an equation between functions. Each step first resolves with
       @{thm fun_cong_unused_0}; if that raises THM, it falls back to plain fun_cong. *)
    fun mk_unabs_def_unused_0 n =
      let
        fun add_arg thm = thm RS @{thm fun_cong_unused_0} handle THM _ => thm RS fun_cong;
      in
        funpow n add_arg
      end;
    64 
    (* Simp rules with which mk_rec_o_map_tac finishes, in addition to the map definitions,
       map_ident0 rules and abs_inverse rules it receives as arguments. *)
    val rec_o_map_simps =
      @{thms o_def[abs_def] id_def case_prod_app case_sum_map_sum case_prod_map_prod id_bnf_def};

    (* Tactic for one "rec_o_map" goal. On the first subgoal it
         1. rewrites with rec_def at occurrences 1 and 2 (subst_tac with SOME [1, 2]);
         2. chains the goal with ctor_rec_o_map via "ctor_rec_o_map RS trans";
         3. eta-expands the goal (Thm.eta_long_conversion);
         4. simplifies with exactly pre_map_defs, the deduplicated map_ident0 and abs_inverse
            rules, and rec_o_map_simps.
       Every step uses the goal context ctxt supplied by the caller. The substitution must not use
       the static @{context}: that is the context this ML code was compiled in. It predates the
       datatype being processed and ignores the caller's context. *)
    fun mk_rec_o_map_tac ctxt rec_def pre_map_defs live_nesting_map_ident0s abs_inverses
        ctor_rec_o_map =
      HEADGOAL (subst_tac ctxt (SOME [1, 2]) [rec_def] THEN' rtac (ctor_rec_o_map RS trans) THEN'
        CONVERSION Thm.eta_long_conversion THEN'
        asm_simp_tac (ss_only (pre_map_defs @
            distinct Thm.eq_thm_prop (live_nesting_map_ident0s @ abs_inverses) @ rec_o_map_simps)
          ctxt));
    75 
    (* Simp rules used by mk_size_o_map_tac besides the injectivity and nested size-o-map rules
       it receives as arguments. *)
    val size_o_map_simps = @{thms prod_inj_map inj_on_id snd_comp_apfst[unfolded apfst_def]};

    (* Tactic for one "size_o_map" goal. It unfolds size_def, chains with "rec_o_map RS trans" and
       simplifies with inj_maps, size_maps and size_o_map_simps. If subgoals remain, it unfolds
       o_def and tries reflexivity. *)
    fun mk_size_o_map_tac ctxt size_def rec_o_map inj_maps size_maps =
      let
        val size_o_map_ss = ss_only (inj_maps @ size_maps @ size_o_map_simps) ctxt;
        val reduce_tac =
          unfold_thms_tac ctxt [size_def] THEN
          HEADGOAL (rtac (rec_o_map RS trans) THEN' asm_simp_tac size_o_map_ss);
        val finish_tac = unfold_thms_tac ctxt @{thms o_def} THEN HEADGOAL (rtac refl);
      in
        reduce_tac THEN IF_UNSOLVED finish_tac
      end;
    83 
    (* Size-function generator for the group of mutually recursive types described by fp_sugars
       (the first must be a least fixpoint). For each type it
         - defines a generic size constant "size_<base name>". The constant takes one "_ => nat"
           function per type argument and is defined through the type's recursor;
         - instantiates class "size" with that constant applied to "%_. 0" for every type argument;
         - derives and notes the "size" simp rules and, when some type argument is live (differs
           between As and Bs), the "rec_o_map" theorems and, if possible, the "size_o_map"
           theorems;
         - registers the result in Data, which mk_size_of_typ consults for nested occurrences.
       Any other argument (second clause) leaves the local theory unchanged. *)
    84 fun generate_datatype_size (fp_sugars as ({T = Type (_, As), BT = Type (_, Bs), fp = Least_FP,
    85       fp_res = {bnfs = fp_bnfs, xtor_co_rec_o_maps = ctor_rec_o_maps, ...}, fp_nesting_bnfs,
    86       live_nesting_bnfs, ...} : fp_sugar) :: _) lthy0 =
    87     let
             (* Size information registered so far; consulted for nested types in mk_size_of_typ. *)
    88       val data = Data.get (Context.Proof lthy0);
    89
    90       val Ts = map #T fp_sugars
    91       val T_names = map (fst o dest_Type) Ts;
    92       val nn = length Ts;
    93
    94       val B_ify = Term.typ_subst_atomic (As ~~ Bs);
    95
             (* Recursors of all types in the group, with their equations. The Cs are the recursors'
                result types; substCnatT instantiates each C to nat. *)
    96       val recs = map (#co_rec o #fp_co_induct_sugar) fp_sugars;
    97       val rec_thmss = map (#co_rec_thms o #fp_co_induct_sugar) fp_sugars;
    98       val rec_Ts as rec_T1 :: _ = map fastype_of recs;
    99       val rec_arg_Ts = binder_fun_types rec_T1;
   100       val Cs = map body_type rec_Ts;
   101       val Cs_rho = map (rpair HOLogic.natT) Cs;
   102       val substCnatT = Term.subst_atomic_types Cs_rho;
   103
             (* One size function f_i : A_i => nat per type argument (fsB: same names, at the Bs). *)
   104       val f_Ts = map mk_to_natT As;
   105       val f_TsB = map mk_to_natT Bs;
   106       val num_As = length As;
   107
             (* n variants of the name pre, made fresh w.r.t. lthy0 by Variable.variant_fixes. *)
   108       fun variant_names n pre = fst (Variable.variant_fixes (replicate n pre) lthy0);
   109
   110       val f_names = variant_names num_As "f";
   111       val fs = map2 (curry Free) f_names f_Ts;
   112       val fsB = map2 (curry Free) f_names f_TsB;
   113       val As_fs = As ~~ fs;
   114
             (* For each type: binding "size_<base>", qualified by the type's base name. *)
   115       val size_bs =
   116         map ((fn base => Binding.qualify false base (Binding.name (prefix size_N base))) o
   117           Long_Name.base_name) T_names;
   118
             (* Recognizes product types "_ * C" whose second component is one of the Cs;
                mk_size_of_typ measures such values with snd. *)
   119       fun is_prod_C @{type_name prod} [_, T'] = member (op =) Cs T'
   120         | is_prod_C _ _ = false;
   121
             (* For a type T, builds a size function for T; its type is T => nat once the Cs are
                instantiated to nat. It works in state-passing style: for each registered nested
                type it uses, it adds that type's size_o_map list (possibly empty) to the accumulator.
                  - type variable: the matching f for an A; identity for a C; otherwise %_. 0
                  - "_ * C" (is_prod_C): snd
                  - other type mentioning an A or a C: the size constant registered in Data for its
                    type constructor, applied to the component size functions; %_. 0 if none
                  - any other type: %_. 0 *)
   122       fun mk_size_of_typ (T as TFree _) =
   123           pair (case AList.lookup (op =) As_fs T of
   124               SOME f => f
   125             | NONE => if member (op =) Cs T then Term.absdummy T (Bound 0) else mk_abs_zero_nat T)
   126         | mk_size_of_typ (T as Type (s, Ts)) =
   127           if is_prod_C s Ts then
   128             pair (snd_const T)
   129           else if exists (exists_subtype_in (As @ Cs)) Ts then
   130             (case Symtab.lookup data s of
   131               SOME (size_name, (_, size_o_maps)) =>
   132               let
   133                 val (args, size_o_mapss') = fold_map mk_size_of_typ Ts [];
   134                 val size_const = Const (size_name, map fastype_of args ---> mk_to_natT T);
   135               in
   136                 append (size_o_maps :: size_o_mapss')
   137                 #> pair (Term.list_comb (size_const, args))
   138               end
   139             | _ => pair (mk_abs_zero_nat T))
   140           else
   141             pair (mk_abs_zero_nat T);
   142
             (* Size of the term t: size function applied to t, beta-reduced, Cs instantiated to nat. *)
   143       fun mk_size_of_arg t =
   144         mk_size_of_typ (fastype_of t) #>> (fn s => substCnatT (betapply (s, t)));
   145
             (* True if some argument type mentions a C, or if no argument type mentions an A. *)
   146       fun is_recursive_or_plain_case Ts =
   147         exists (exists_subtype_in Cs) Ts orelse forall (not o exists_subtype_in As) Ts;
   148
   149       (* We want the size function to enjoy the following properties:
   150
   151           1. The size of a list should coincide with its length.
   152           2. All the nonrecursive constructors of a type should have the same size.
   153           3. Each constructor through which nested recursion takes place should count as at least
   154              one in the generic size function.
   155           4. The "size" function should be definable as "size_t (%_. 0) ... (%_. 0)", where "size_t"
   156              is the generic function.
   157
   158          This explains the somewhat convoluted logic ahead. *)
   159
             (* Value for a recursor argument none of whose summands survives. It is 0 if every
                recursor argument passes is_recursive_or_plain_case, and Suc 0 otherwise. *)
   160       val base_case =
   161         if forall (is_recursive_or_plain_case o binder_types) rec_arg_Ts then HOLogic.zero
   162         else HOLogic.Suc_zero;
   163
             (* Argument passed to the recursor for one rec_arg_T:
                "%x1 ... xm. s1 + ... + sk + Suc 0", where the s_i are the sizes of the xs that are
                not literally zero; base_case if there are none. Also accumulates size_o_map lists. *)
   164       fun mk_size_arg rec_arg_T =
   165         let
   166           val x_Ts = binder_types rec_arg_T;
   167           val m = length x_Ts;
   168           val x_names = variant_names m "x";
   169           val xs = map2 (curry Free) x_names x_Ts;
   170           val (summands, size_o_mapss) =
   171             fold_map mk_size_of_arg xs []
   172             |>> remove (op =) HOLogic.zero;
   173           val sum =
   174             if null summands then base_case
   175             else foldl1 mk_plus_nat (summands @ [HOLogic.Suc_zero]);
   176         in
   177           append size_o_mapss
   178           #> pair (fold_rev Term.lambda (map substCnatT xs) sum)
   179         end;
   180
             (* Right-hand side of a generic size definition: "%fs. rec args", Cs instantiated to nat. *)
   181       fun mk_size_rhs recx =
   182         fold_map mk_size_arg rec_arg_Ts
   183         #>> (fn args => fold_rev Term.lambda fs (Term.list_comb (substCnatT recx, args)));
   184
             (* Definition bindings are concealed unless bnf_note_all is set. *)
   185       val maybe_conceal_def_binding = Thm.def_binding
   186         #> not (Config.get lthy0 bnf_note_all) ? Binding.conceal;
   187
             (* nested_size_o_mapss gets one entry per use of a registered nested size constant. *)
   188       val (size_rhss, nested_size_o_mapss) = fold_map mk_size_rhs recs [];
   189       val size_Ts = map fastype_of size_rhss;
   190
             (* size_o_map theorems are attempted only if every such use contributed at least one
                size_o_map theorem (vacuously true when there are none). *)
   191       val nested_size_o_maps_complete = forall (not o null) nested_size_o_mapss;
   192       val nested_size_o_maps = fold (union Thm.eq_thm_prop) nested_size_o_mapss [];
   193
             (* Define the generic size constants. Restoring the local theory gives lthy1', so that
                phi exports the definitions and constants from lthy1 to lthy1'. *)
   194       val ((raw_size_consts, raw_size_defs), (lthy1', lthy1)) = lthy0
   195         |> apfst split_list o fold_map2 (fn b => fn rhs =>
   196             Local_Theory.define ((b, NoSyn), ((maybe_conceal_def_binding b, []), rhs))
   197             #>> apsnd snd)
   198           size_bs size_rhss
   199         ||> `Local_Theory.restore;
   200
   201       val phi = Proof_Context.export_morphism lthy1 lthy1';
   202
   203       val size_defs = map (Morphism.thm phi) raw_size_defs;
   204
             (* Exported constants retyped to the right-hand sides' types; size_constsB are the copies
                at the Bs. *)
   205       val size_consts0 = map (Morphism.term phi) raw_size_consts;
   206       val size_consts = map2 retype_const_or_free size_Ts size_consts0;
   207       val size_constsB = map (Term.map_types B_ify) size_consts;
   208
             (* Overloaded "size" of each type: its generic size applied to %_. 0 per type argument. *)
   209       val zeros = map mk_abs_zero_nat As;
   210
   211       val overloaded_size_rhss = map (fn c => Term.list_comb (c, zeros)) size_consts;
   212       val overloaded_size_Ts = map fastype_of overloaded_size_rhss;
   213       val overloaded_size_consts = map (curry Const @{const_name size}) overloaded_size_Ts;
   214       val overloaded_size_def_bs =
   215         map (maybe_conceal_def_binding o Binding.suffix_name "_overloaded") size_bs;
   216
             (* Inside the class instantiation: define lhs0 as rhs. Checking lhs0 must yield a Free
                (the val pattern fails otherwise). The definition is recorded as an equational
                Spec_Rule, and the theorem is exported to the context of the underlying theory. *)
   217       fun define_overloaded_size def_b lhs0 rhs lthy =
   218         let
   219           val Free (c, _) = Syntax.check_term lthy lhs0;
   220           val (thm, lthy') = lthy
   221             |> Local_Theory.define ((Binding.name c, NoSyn), ((def_b, []), rhs))
   222             |-> (fn (t, (_, thm)) => Spec_Rules.add Spec_Rules.Equational ([t], [thm]) #> pair thm);
   223           val ctxt_thy = Proof_Context.init_global (Proof_Context.theory_of lthy');
   224           val thm' = singleton (Proof_Context.export lthy' ctxt_thy) thm;
   225         in (thm', lthy') end;
   226
             (* In the background theory: instantiate class size for all types of the group, define the
                overloaded sizes, prove the instance by intro_classes, and exit to the theory. *)
   227       val (overloaded_size_defs, lthy2) = lthy1
   228         |> Local_Theory.background_theory_result
   229           (Class.instantiation (T_names, map dest_TFree As, [HOLogic.class_size])
   230            #> fold_map3 define_overloaded_size overloaded_size_def_bs overloaded_size_consts
   231              overloaded_size_rhss
   232            ##> Class.prove_instantiation_instance (K (Class.intro_classes_tac []))
   233            ##> Local_Theory.exit_global);
   234
             (* Object-level equations of the definitions, applied to num_As + 1 arguments (generic)
                or 1 argument (overloaded). *)
   235       val size_defs' =
   236         map (mk_unabs_def (num_As + 1) o (fn thm => thm RS meta_eq_to_obj_eq)) size_defs;
   237       val size_defs_unused_0 =
   238         map (mk_unabs_def_unused_0 (num_As + 1) o (fn thm => thm RS meta_eq_to_obj_eq)) size_defs;
   239       val overloaded_size_defs' =
   240         map (mk_unabs_def 1 o (fn thm => thm RS meta_eq_to_obj_eq)) overloaded_size_defs;
   241
             (* Equations for folding overloaded size back in: this group's definitions plus the
                single-theorem equational Spec_Rules for size already recorded in lthy0. *)
   242       val all_overloaded_size_defs = overloaded_size_defs @
   243         (Spec_Rules.retrieve lthy0 @{const size ('a)}
   244          |> map_filter (try (fn (Spec_Rules.Equational, (_, [thm])) => thm)));
   245
             (* Nested size_o_map theorems in applied and compositional form, and the deduplicated
                injectivity theorems of the maps of all involved BNFs. *)
   246       val nested_size_maps = map (mk_pointfull lthy2) nested_size_o_maps @ nested_size_o_maps;
   247       val all_inj_maps = map inj_map_of_bnf (fp_bnfs @ fp_nesting_bnfs @ live_nesting_bnfs)
   248         |> distinct Thm.eq_thm_prop;
   249
             (* Generic "size" simp rule: chain a definition with a recursor equation, simplify with
                a fixed rule set, and fold generic size applications back in. *)
   250       fun derive_size_simp size_def' simp0 =
   251         (trans OF [size_def', simp0])
   252         |> Simplifier.asm_full_simplify (ss_only (@{thms inj_on_convol_ident id_def o_def
   253           snd_conv} @ all_inj_maps @ nested_size_maps) lthy2)
   254         |> fold_thms lthy2 size_defs_unused_0;
   255
             (* Overloaded "size" simp rule: chain the overloaded definition with a generic simp rule,
                drop "0 + _" / "_ + 0", and fold overloaded size applications back in. *)
   256       fun derive_overloaded_size_simp overloaded_size_def' simp0 =
   257         (trans OF [overloaded_size_def', simp0])
   258         |> unfold_thms lthy2 @{thms add_0_left add_0_right}
   259         |> fold_thms lthy2 all_overloaded_size_defs;
   260
   261       val size_simpss = map2 (map o derive_size_simp) size_defs' rec_thmss;
   262       val size_simps = flat size_simpss;
   263       val overloaded_size_simpss =
   264         map2 (map o derive_overloaded_size_simp) overloaded_size_defs' size_simpss;
   265       val size_thmss = map2 append size_simpss overloaded_size_simpss;
   266
             (* g_i : A_i => B_i per type argument; live_gs are those with A_i <> B_i. *)
   267       val ABs = As ~~ Bs;
   268       val g_names = variant_names num_As "g";
   269       val gs = map2 (curry Free) g_names (map (op -->) ABs);
   270
   271       val liveness = map (op <>) ABs;
   272       val live_gs = AList.find (op =) (gs ~~ liveness) true;
   273       val live = length live_gs;
   274
   275       val maps0 = map map_of_bnf fp_bnfs;
   276
             (* rec_o_map and size_o_map theorems, one list per type; all empty if nothing is live. *)
   277       val (rec_o_map_thmss, size_o_map_thmss) =
   278         if live = 0 then
   279           `I (replicate nn [])
   280         else
   281           let
   282             val pre_bnfs = map #pre_bnf fp_sugars;
   283             val pre_map_defs = map map_def_of_bnf pre_bnfs;
   284             val live_nesting_map_ident0s = map map_ident0_of_bnf live_nesting_bnfs;
   285             val abs_inverses = map (#abs_inverse o #absT_info) fp_sugars;
   286             val rec_defs = map (#co_rec_def o #fp_co_induct_sugar) fp_sugars;
   287
                   (* The group's BNF maps applied to the live gs. *)
   288             val gmaps = map (fn map0 => Term.list_comb (mk_map live As Bs map0, live_gs)) maps0;
   289
                   (* Recursor arguments hs at the Bs; hrecs are the recursors at the Bs applied to them. *)
   290             val num_rec_args = length rec_arg_Ts;
   291             val h_Ts = map B_ify rec_arg_Ts;
   292             val h_names = variant_names num_rec_args "h";
   293             val hs = map2 (curry Free) h_names h_Ts;
   294             val hrecs = map (fn recx => Term.list_comb (Term.map_types B_ify recx, hs)) recs;
   295
   296             val rec_o_map_lhss = map2 (curry HOLogic.mk_comp) hrecs gmaps;
   297
   298             val ABgs = ABs ~~ gs;
   299
                   (* Transport x from its type T to B_ify T via build_map over the gs;
                      unchanged when B_ify leaves T unchanged. *)
   300             fun mk_rec_arg_arg (x as Free (_, T)) =
   301               let val U = B_ify T in
   302                 if T = U then x else build_map lthy2 [] (the o AList.lookup (op =) ABgs) (T, U) $ x
   303               end;
   304
                   (* "%xs. h xs'", where xs' are the xs transported to the Bs. *)
   305             fun mk_rec_o_map_arg rec_arg_T h =
   306               let
   307                 val x_Ts = binder_types rec_arg_T;
   308                 val m = length x_Ts;
   309                 val x_names = variant_names m "x";
   310                 val xs = map2 (curry Free) x_names x_Ts;
   311                 val xs' = map mk_rec_arg_arg xs;
   312               in
   313                 fold_rev Term.lambda xs (Term.list_comb (h, xs'))
   314               end;
   315
   316             fun mk_rec_o_map_rhs recx =
   317               let val args = map2 mk_rec_o_map_arg rec_arg_Ts hs in
   318                 Term.list_comb (recx, args)
   319               end;
   320
   321             val rec_o_map_rhss = map mk_rec_o_map_rhs recs;
   322
                   (* Goals "rec_B hs o map gs = rec_A (transported hs)", quantified over gs and hs. *)
   323             val rec_o_map_goals =
   324               map2 (fold_rev (fold_rev Logic.all) [gs, hs] o HOLogic.mk_Trueprop oo
   325                 curry HOLogic.mk_eq) rec_o_map_lhss rec_o_map_rhss;
   326             val rec_o_map_thms =
   327               map3 (fn goal => fn rec_def => fn ctor_rec_o_map =>
   328                   Goal.prove_sorry lthy2 [] [] goal (fn {context = ctxt, ...} =>
   329                     mk_rec_o_map_tac ctxt rec_def pre_map_defs live_nesting_map_ident0s abs_inverses
   330                       ctor_rec_o_map)
   331                   |> Thm.close_derivation)
   332                 rec_o_map_goals rec_defs ctor_rec_o_maps;
   333
                   (* Injectivity premises on the live gs, only needed if some nested size_o_map
                      theorem is itself an implication. *)
   334             val size_o_map_conds =
   335               if exists (can Logic.dest_implies o Thm.prop_of) nested_size_o_maps then
   336                 map (HOLogic.mk_Trueprop o mk_inj) live_gs
   337               else
   338                 [];
   339
   340             val fsizes = map (fn size_constB => Term.list_comb (size_constB, fsB)) size_constsB;
   341             val size_o_map_lhss = map2 (curry HOLogic.mk_comp) fsizes gmaps;
   342
                   (* Compose f with g where A <> B, else keep f. The fn pattern always matches:
                      every g is a Free of function type (see gs above). *)
   343             val fgs = map2 (fn fB => fn g as Free (_, Type (_, [A, B])) =>
   344               if A = B then fB else HOLogic.mk_comp (fB, g)) fsB gs;
   345             val size_o_map_rhss = map (fn c => Term.list_comb (c, fgs)) size_consts;
   346
                   (* Goals "size_B fsB o map gs = size_A fgs" under size_o_map_conds, quantified over
                      fsB and gs. *)
   347             val size_o_map_goals =
   348               map2 (fold_rev (fold_rev Logic.all) [fsB, gs] o
   349                 curry Logic.list_implies size_o_map_conds o HOLogic.mk_Trueprop oo
   350                 curry HOLogic.mk_eq) size_o_map_lhss size_o_map_rhss;
   351
                   (* Proved only when nested_size_o_maps_complete; otherwise no size_o_map theorems. *)
   352             val size_o_map_thmss =
   353               if nested_size_o_maps_complete then
   354                 map3 (fn goal => fn size_def => fn rec_o_map =>
   355                     Goal.prove_sorry lthy2 [] [] goal (fn {context = ctxt, ...} =>
   356                       mk_size_o_map_tac ctxt size_def rec_o_map all_inj_maps nested_size_maps)
   357                     |> Thm.close_derivation
   358                     |> single)
   359                   size_o_map_goals size_defs rec_o_map_thms
   360               else
   361                 replicate nn [];
   362           in
   363             (map single rec_o_map_thms, size_o_map_thmss)
   364           end;
   365
   366       (* Ideally, the "[code]" attribute would be generated only if the "code" plugin is enabled. *)
   367       val code_attrs = Code.add_default_eqn_attrib;
   368
             (* For each (name, per-type theorem lists, attributes): one note per type, named
                "<type base name>.<name>". Notes whose theorem list is empty are dropped. *)
   369       val massage_multi_notes =
   370         maps (fn (thmN, thmss, attrs) =>
   371           map2 (fn T_name => fn thms =>
   372               ((Binding.qualify true (Long_Name.base_name T_name) (Binding.name thmN), attrs),
   373                [(thms, [])]))
   374             T_names thmss)
   375         #> filter_out (null o fst o hd o snd);
   376
   377       val notes =
   378         [(rec_o_mapN, rec_o_map_thmss, []),
   379          (sizeN, size_thmss, code_attrs :: nitpicksimp_attrs @ simp_attrs),
   380          (size_o_mapN, size_o_map_thmss, [])]
   381         |> massage_multi_notes;
   382
             (* Record size_simps as equational Spec_Rules for the size constants, then note. *)
   383       val (noted, lthy3) =
   384         lthy2
   385         |> Spec_Rules.add Spec_Rules.Equational (size_consts, size_simps)
   386         |> Local_Theory.notes notes;
   387
             (* Morphism built from the noted facts; composed with the declaration's phi below. *)
   388       val phi0 = substitute_noted_thm noted;
   389     in
             (* Register each type's size constant name in Data.
                NOTE(review): every type of the group receives the same theorem lists (all generic
                size simps and all size_o_map theorems of the group). The "fn Const" match raises
                Match if an exported size constant is not a Const. *)
   390       lthy3
   391       |> Local_Theory.declaration {syntax = false, pervasive = true}
   392         (fn phi => Data.map (fold2 (fn T_name => fn Const (size_name, _) =>
   393              Symtab.update (T_name, (size_name,
   394                pairself (map (Morphism.thm (phi0 $> phi))) (size_simps, flat size_o_map_thmss))))
   395            T_names size_consts))
   396     end
         (* Anything else (e.g. not a least fixpoint): nothing to do. *)
   397   | generate_datatype_size _ lthy = lthy;
   398 
   (* Install the generator as the "size" plugin. fp_sugars_interpretation receives a theory-level
      variant, obtained with map_local_theory, and the local-theory function itself. *)
   val _ =
     let
       fun size_on_theory fp_sugars = map_local_theory (generate_datatype_size fp_sugars);
     in
       Theory.setup (fp_sugars_interpretation size_plugin size_on_theory generate_datatype_size)
     end;
   401 
   402 end;