src/HOL/Tools/BNF/bnf_lfp_size.ML
changeset 62082 614ef6d7a6b6
parent 61786 6c42d55097c1
child 62093 bd73a2279fcd
--- a/src/HOL/Tools/BNF/bnf_lfp_size.ML	Wed Jan 06 13:04:31 2016 +0100
+++ b/src/HOL/Tools/BNF/bnf_lfp_size.ML	Wed Jan 06 13:04:31 2016 +0100
@@ -7,10 +7,10 @@
 
 signature BNF_LFP_SIZE =
 sig
-  val register_size: string -> string -> thm list -> thm list -> local_theory -> local_theory
-  val register_size_global: string -> string -> thm list -> thm list -> theory -> theory
-  val size_of: Proof.context -> string -> (string * (thm list * thm list)) option
-  val size_of_global: theory -> string -> (string * (thm list * thm list)) option
+  val register_size: string -> string -> thm -> thm list -> thm list -> local_theory -> local_theory
+  val register_size_global: string -> string -> thm -> thm list -> thm list -> theory -> theory
+  val size_of: Proof.context -> string -> (string * (thm * thm list * thm list)) option
+  val size_of_global: theory -> string -> (string * (thm * thm list * thm list)) option
 end;
 
 structure BNF_LFP_Size : BNF_LFP_SIZE =
@@ -44,7 +44,7 @@
 
 structure Data = Generic_Data
 (
-  type T = (string * (thm list * thm list)) Symtab.table;
+  type T = (string * (thm * thm list * thm list)) Symtab.table;
   val empty = Symtab.empty;
   val extend = I
   fun merge data = Symtab.merge (K true) data;
@@ -63,19 +63,25 @@
       " must have type\n" ^ quote (Syntax.string_of_typ_global thy size_T))
   end;
 
-fun register_size T_name size_name size_simps size_gen_o_maps lthy =
+fun register_size T_name size_name overloaded_size_def size_simps size_gen_o_maps lthy =
   (check_size_type (Proof_Context.theory_of lthy) T_name size_name;
-   Context.proof_map (Data.map (Symtab.update (T_name, (size_name, (size_simps, size_gen_o_maps)))))
+   Context.proof_map (Data.map (Symtab.update
+       (T_name, (size_name, (overloaded_size_def, size_simps, size_gen_o_maps)))))
      lthy);
 
-fun register_size_global T_name size_name size_simps size_gen_o_maps thy =
+fun register_size_global T_name size_name overloaded_size_def size_simps size_gen_o_maps thy =
   (check_size_type thy T_name size_name;
-   Context.theory_map (Data.map (Symtab.update (T_name, (size_name, (size_simps,
-     size_gen_o_maps))))) thy);
+   Context.theory_map (Data.map (Symtab.update
+       (T_name, (size_name, (overloaded_size_def, size_simps, size_gen_o_maps)))))
+     thy);
 
 val size_of = Symtab.lookup o Data.get o Context.Proof;
 val size_of_global = Symtab.lookup o Data.get o Context.Theory;
 
+fun all_overloaded_size_defs_of ctxt =
+  Symtab.fold (fn (_, (_, (overloaded_size_def, _, _))) => cons overloaded_size_def)
+    (Data.get (Context.Proof ctxt)) [];
+
 val size_gen_o_map_simps = @{thms inj_on_id snd_comp_apfst[unfolded apfst_def]};
 
 fun mk_size_gen_o_map_tac ctxt size_def rec_o_map inj_maps size_maps =
@@ -138,7 +144,7 @@
             pair (snd_const T)
           else if exists (exists_subtype_in (As @ Cs)) Ts then
             (case Symtab.lookup data s of
-              SOME (size_name, (_, size_gen_o_maps)) =>
+              SOME (size_name, (_, _, size_gen_o_maps)) =>
               let
                 val (args, size_gen_o_mapss') = fold_map mk_size_of_typ Ts [];
                 val size_T = map fastype_of args ---> mk_to_natT T;
@@ -229,9 +235,8 @@
       fun define_overloaded_size def_b lhs0 rhs lthy =
         let
           val Free (c, _) = Syntax.check_term lthy lhs0;
-          val (thm, lthy') = lthy
-            |> Local_Theory.define ((Binding.name c, NoSyn), ((def_b, []), rhs))
-            |-> (fn (t, (_, thm)) => Spec_Rules.add Spec_Rules.Equational ([t], [thm]) #> pair thm);
+          val ((_, (_, thm)), lthy') = lthy
+            |> Local_Theory.define ((Binding.name c, NoSyn), ((def_b, []), rhs));
           val thy_ctxt = Proof_Context.init_global (Proof_Context.theory_of lthy');
           val thm' = singleton (Proof_Context.export lthy' thy_ctxt) thm;
         in (thm', lthy') end;
@@ -251,10 +256,6 @@
       val overloaded_size_defs' =
         map (mk_unabs_def 1 o (fn thm => thm RS meta_eq_to_obj_eq)) overloaded_size_defs;
 
-      val all_overloaded_size_defs = overloaded_size_defs @
-        (Spec_Rules.retrieve lthy0 @{const size ('a)}
-         |> map_filter (try (fn (Spec_Rules.Equational, (_, [thm])) => thm)));
-
       val nested_size_maps =
         map (mk_pointful lthy2) nested_size_gen_o_maps @ nested_size_gen_o_maps;
       val all_inj_maps =
@@ -270,23 +271,24 @@
       fun derive_overloaded_size_simp overloaded_size_def' simp0 =
         (trans OF [overloaded_size_def', simp0])
         |> unfold_thms lthy2 @{thms add_0_left add_0_right}
-        |> fold_thms lthy2 all_overloaded_size_defs;
+        |> fold_thms lthy2 (overloaded_size_defs @ all_overloaded_size_defs_of lthy2);
 
       val size_simpss = map2 (map o derive_size_simp) size_defs' rec_thmss;
       val size_simps = flat size_simpss;
       val overloaded_size_simpss =
         map2 (map o derive_overloaded_size_simp) overloaded_size_defs' size_simpss;
+      val overloaded_size_simps = flat overloaded_size_simpss;
       val size_thmss = map2 append size_simpss overloaded_size_simpss;
       val size_gen_thmss = size_simpss
       fun rhs_is_zero thm =
         let val Const (trueprop, _) $ (Const (eq, _) $ _ $ rhs) = Thm.prop_of thm in
-          trueprop = @{const_name Trueprop} andalso
-          eq = @{const_name HOL.eq} andalso
+          trueprop = @{const_name Trueprop} andalso eq = @{const_name HOL.eq} andalso
           rhs = HOLogic.zero
         end;
 
       val size_neq_thmss = @{map 3} (fn fp_sugar => fn size => fn size_thms =>
-        if exists rhs_is_zero size_thms then []
+        if exists rhs_is_zero size_thms then
+          []
         else
           let
             val (xs, _) = mk_Frees "x" (binder_types (fastype_of size)) lthy2;
@@ -299,7 +301,7 @@
                 (#exhaust (#ctr_sugar (#fp_ctr_sugar fp_sugar))) size_thms)
               |> single
               |> map Thm.close_derivation;
-          in thm end) fp_sugars overloaded_size_consts overloaded_size_simpss
+          in thm end) fp_sugars overloaded_size_consts overloaded_size_simpss;
 
       val ABs = As ~~ Bs;
       val g_names = variant_names num_As "g";
@@ -373,16 +375,20 @@
       val (noted, lthy3) =
         lthy2
         |> Spec_Rules.add Spec_Rules.Equational (size_consts, size_simps)
+        |> Spec_Rules.add Spec_Rules.Equational (overloaded_size_consts, overloaded_size_simps)
         |> Local_Theory.notes notes;
 
       val phi0 = substitute_noted_thm noted;
     in
       lthy3
       |> Local_Theory.declaration {syntax = false, pervasive = true}
-        (fn phi => Data.map (fold2 (fn T_name => fn Const (size_name, _) =>
-             Symtab.update (T_name, (size_name,
-               apply2 (map (Morphism.thm (phi0 $> phi))) (size_simps, flat size_gen_o_map_thmss))))
-           T_names size_consts))
+        (fn phi => Data.map (@{fold 3} (fn T_name => fn Const (size_name, _) =>
+            fn overloaded_size_def =>
+               let val morph = Morphism.thm (phi0 $> phi) in
+                 Symtab.update (T_name, (size_name, (morph overloaded_size_def, map morph size_simps,
+                   maps (map morph) size_gen_o_map_thmss)))
+               end)
+           T_names size_consts overloaded_size_defs))
     end
   | generate_datatype_size _ lthy = lthy;