made old-style 'size' interpretation hook more robust in case 'size' is already specified (either by the user or by the new datatype package)
authorblanchet
Wed, 23 Apr 2014 10:23:26 +0200
changeset 56639 c9d6b581bd3b
parent 56638 092a306bcc3d
child 56640 0a35354137a5
made old-style 'size' interpretation hook more robust in case 'size' is already specified (either by the user or by the new datatype package)
src/HOL/BNF_LFP.thy
src/HOL/Tools/BNF/bnf_lfp_size.ML
src/HOL/Tools/Function/size.ML
--- a/src/HOL/BNF_LFP.thy	Wed Apr 23 10:23:26 2014 +0200
+++ b/src/HOL/BNF_LFP.thy	Wed Apr 23 10:23:26 2014 +0200
@@ -201,63 +201,4 @@
 
 hide_fact (open) id_transfer
 
-datatype_new x = X nat
-thm x.size
-
-datatype_new 'a l = N | C 'a "'a l"
-thm l.size
-thm l.size_map
-thm size_l_def size_l_overloaded_def
-
-datatype_new
-  'a tl = TN | TC "'a mt" "'a tl" and
-  'a mt = MT 'a "'a tl"
-
-thm size_tl_def size_tl_overloaded_def
-thm size_mt_def size_mt_overloaded_def
-
-datatype_new 'a t = T 'a "'a t l"
-thm t.size
-
-lemma size_l_cong: "(ALL x : set_l t. f x = g x) --> size_l f t = size_l g t"
-  apply (induct_tac t)
-  apply (simp only: l.size simp_thms)
-  apply (simp add: l.set l.size simp_thms)
-  done
-
-lemma t_size_map_t: "size_t g (map_t f t) = size_t (g \<circ> f) t"
-  apply (rule t.induct)
-  apply (simp_all only: t.map t.size comp_def l.size_map)
-  apply (auto intro: size_l_cong)
-  apply (subst size_l_cong[rule_format], assumption)
-  apply (rule refl)
-  done
-
-
-thm t.size
-
-lemmas size_t_def' =
-  size_t_def[THEN meta_eq_to_obj_eq, THEN fun_cong, THEN fun_cong]
-
-thm trans[OF size_t_def' t.rec(1), unfolded l.size_map snd_o_convol, folded size_t_def']
-
-lemma "size_t f (T x ts) = f x + size_l (size_t f) ts + Suc 0"
-  unfolding size_t_def t.rec l.size_map snd_o_convol
-  by rule
-
-
-lemma "       (\<And>x2aa. x2aa \<in> set_l x2a \<Longrightarrow>
-                size_t f1 (map_t g1 x2aa) = size_t (f1 \<circ> g1) x2aa) \<Longrightarrow>
-       f1 (g1 x1a) +
-       size_l snd (map_l (\<lambda>t. (t, size_t f1 t)) (map_l (map_t g1) x2a)) +
-       Suc 0 =
-       f1 (g1 x1a) + size_l snd (map_l (\<lambda>t. (t, size_t (\<lambda>x1. f1 (g1 x1)) t)) x2a) +
-       Suc 0"
-apply (simp only: l.size_map comp_def snd_conv t.size_map snd_o_convol)
-
-thm size_t_def size_t_overloaded_def
-
-xdatatype_new ('a, 'b, 'c) x = XN 'c | XC 'a "('a, 'b, 'c) x"
-thm size_x_def size_x_overloaded_def
-
 end
--- a/src/HOL/Tools/BNF/bnf_lfp_size.ML	Wed Apr 23 10:23:26 2014 +0200
+++ b/src/HOL/Tools/BNF/bnf_lfp_size.ML	Wed Apr 23 10:23:26 2014 +0200
@@ -236,6 +236,7 @@
       T_names gen_size_names)
   end;
 
-val _ = Theory.setup (fp_sugar_interpretation generate_size);
+(* FIXME: get rid of "perhaps o try" once the code is stable *)
+val _ = Theory.setup (fp_sugar_interpretation (perhaps o try o generate_size));
 
 end;
--- a/src/HOL/Tools/Function/size.ML	Wed Apr 23 10:23:26 2014 +0200
+++ b/src/HOL/Tools/Function/size.ML	Wed Apr 23 10:23:26 2014 +0200
@@ -59,162 +59,170 @@
     val {descr, rec_names, rec_rewrites, induct, ...} = info;
     val l = length new_type_names;
     val descr' = List.take (descr, l);
-    val (rec_names1, rec_names2) = chop l rec_names;
-    val recTs = Datatype_Aux.get_rec_types descr;
-    val (recTs1, recTs2) = chop l recTs;
-    val (_, (_, paramdts, _)) :: _ = descr;
-    val paramTs = map (Datatype_Aux.typ_of_dtyp descr) paramdts;
-    val ((param_size_fs, param_size_fTs), f_names) = paramTs |>
-      map (fn T as TFree (s, _) =>
-        let
-          val name = "f" ^ unprefix "'" s;
-          val U = T --> HOLogic.natT
-        in
-          (((s, Free (name, U)), U), name)
-        end) |> split_list |>> split_list;
-    val param_size = AList.lookup op = param_size_fs;
-
-    val extra_rewrites = descr |> map (#1 o snd) |> distinct op = |>
-      map_filter (Option.map snd o lookup_size thy) |> flat;
-    val extra_size = Option.map fst o lookup_size thy;
-
-    val (((size_names, size_fns), def_names), def_names') =
-      recTs1 |> map (fn T as Type (s, _) =>
-        let
-          val s' = Long_Name.base_name s ^ "_size";
-          val s'' = Sign.full_bname thy s';
-        in
-          (s'',
-           (list_comb (Const (s'', param_size_fTs @ [T] ---> HOLogic.natT),
-              map snd param_size_fs),
-            (s' ^ "_def", s' ^ "_overloaded_def")))
-        end) |> split_list ||>> split_list ||>> split_list;
-    val overloaded_size_fns = map HOLogic.size_const recTs1;
-
-    (* instantiation for primrec combinator *)
-    fun size_of_constr b size_ofp ((_, cargs), (_, cargs')) =
-      let
-        val Ts = map (Datatype_Aux.typ_of_dtyp descr) cargs;
-        val k = length (filter Datatype_Aux.is_rec_type cargs);
-        val (ts, _, _) = fold_rev (fn ((dt, dt'), T) => fn (us, i, j) =>
-          if Datatype_Aux.is_rec_type dt then (Bound i :: us, i + 1, j + 1)
-          else
-            (if b andalso is_poly thy dt' then
-               case size_of_type (K NONE) extra_size size_ofp T of
-                 NONE => us | SOME sz => sz $ Bound j :: us
-             else us, i, j + 1))
-              (cargs ~~ cargs' ~~ Ts) ([], 0, k);
-        val t =
-          if null ts andalso (not b orelse not (exists (is_poly thy) cargs'))
-          then HOLogic.zero
-          else foldl1 plus (ts @ [HOLogic.Suc_zero])
-      in
-        fold_rev (fn T => fn t' => Abs ("x", T, t')) (Ts @ replicate k HOLogic.natT) t
-      end;
-
-    val fs = maps (fn (_, (name, _, constrs)) =>
-      map (size_of_constr true param_size) (constrs ~~ constrs_of thy name)) descr;
-    val fs' = maps (fn (n, (name, _, constrs)) =>
-      map (size_of_constr (l <= n) (K NONE)) (constrs ~~ constrs_of thy name)) descr;
-    val fTs = map fastype_of fs;
-
-    val (rec_combs1, rec_combs2) = chop l (map (fn (T, rec_name) =>
-      Const (rec_name, fTs @ [T] ---> HOLogic.natT))
-        (recTs ~~ rec_names));
-
-    fun define_overloaded (def_name, eq) lthy =
-      let
-        val (Free (c, _), rhs) = (Logic.dest_equals o Syntax.check_term lthy) eq;
-        val (thm, lthy') = lthy
-          |> Local_Theory.define ((Binding.name c, NoSyn), ((Binding.name def_name, []), rhs))
-          |-> (fn (t, (_, thm)) => Spec_Rules.add Spec_Rules.Equational ([t], [thm]) #> pair thm);
-        val ctxt_thy = Proof_Context.init_global (Proof_Context.theory_of lthy');
-        val thm' = singleton (Proof_Context.export lthy' ctxt_thy) thm;
-      in (thm', lthy') end;
-
-    val ((size_def_thms, size_def_thms'), thy') =
+    val tycos = map (#1 o snd) descr';
+  in
+    if forall (fn tyco => can (Sign.arity_sorts thy tyco) [HOLogic.class_size]) tycos then
+      (* nothing to do -- the "size" function is already defined *)
       thy
-      |> Sign.add_consts (map (fn (s, T) =>
-           (Binding.name (Long_Name.base_name s), param_size_fTs @ [T] ---> HOLogic.natT, NoSyn))
-           (size_names ~~ recTs1))
-      |> Global_Theory.add_defs false
-        (map (Thm.no_attributes o apsnd (Logic.mk_equals o apsnd (app fs)))
-           (map Binding.name def_names ~~ (size_fns ~~ rec_combs1)))
-      ||> Class.instantiation
-           (map (#1 o snd) descr', map dest_TFree paramTs, [HOLogic.class_size])
-      ||>> fold_map define_overloaded
-        (def_names' ~~ map Logic.mk_equals (overloaded_size_fns ~~ map (app fs') rec_combs1))
-      ||> Class.prove_instantiation_instance (K (Class.intro_classes_tac []))
-      ||> Local_Theory.exit_global;
-
-    val ctxt = Proof_Context.init_global thy';
-
-    val simpset1 =
-      put_simpset HOL_basic_ss ctxt addsimps @{thm Nat.add_0} :: @{thm Nat.add_0_right} ::
-        size_def_thms @ size_def_thms' @ rec_rewrites @ extra_rewrites;
-    val xs = map (fn i => "x" ^ string_of_int i) (1 upto length recTs2);
-
-    fun mk_unfolded_size_eq tab size_ofp fs (p as (x, T), r) =
-      HOLogic.mk_eq (app fs r $ Free p,
-        the (size_of_type tab extra_size size_ofp T) $ Free p);
-
-    fun prove_unfolded_size_eqs size_ofp fs =
-      if null recTs2 then []
-      else Datatype_Aux.split_conj_thm (Goal.prove_sorry ctxt xs []
-        (HOLogic.mk_Trueprop (Datatype_Aux.mk_conj (replicate l @{term True} @
-           map (mk_unfolded_size_eq (AList.lookup op =
-               (new_type_names ~~ map (app fs) rec_combs1)) size_ofp fs)
-             (xs ~~ recTs2 ~~ rec_combs2))))
-        (fn _ => (Datatype_Aux.ind_tac induct xs THEN_ALL_NEW asm_simp_tac simpset1) 1));
-
-    val unfolded_size_eqs1 = prove_unfolded_size_eqs param_size fs;
-    val unfolded_size_eqs2 = prove_unfolded_size_eqs (K NONE) fs';
-
-    (* characteristic equations for size functions *)
-    fun gen_mk_size_eq p size_of size_ofp size_const T (cname, cargs) =
+    else
       let
-        val Ts = map (Datatype_Aux.typ_of_dtyp descr) cargs;
-        val tnames = Name.variant_list f_names (Datatype_Prop.make_tnames Ts);
-        val ts = map_filter (fn (sT as (s, T), dt) =>
-          Option.map (fn sz => sz $ Free sT)
-            (if p dt then size_of_type size_of extra_size size_ofp T
-             else NONE)) (tnames ~~ Ts ~~ cargs)
-      in
-        HOLogic.mk_Trueprop (HOLogic.mk_eq
-          (size_const $ list_comb (Const (cname, Ts ---> T),
-             map2 (curry Free) tnames Ts),
-           if null ts then HOLogic.zero
-           else foldl1 plus (ts @ [HOLogic.Suc_zero])))
-      end;
+        val (rec_names1, rec_names2) = chop l rec_names;
+        val recTs = Datatype_Aux.get_rec_types descr;
+        val (recTs1, recTs2) = chop l recTs;
+        val (_, (_, paramdts, _)) :: _ = descr;
+        val paramTs = map (Datatype_Aux.typ_of_dtyp descr) paramdts;
+        val ((param_size_fs, param_size_fTs), f_names) = paramTs |>
+          map (fn T as TFree (s, _) =>
+            let
+              val name = "f" ^ unprefix "'" s;
+              val U = T --> HOLogic.natT
+            in
+              (((s, Free (name, U)), U), name)
+            end) |> split_list |>> split_list;
+        val param_size = AList.lookup op = param_size_fs;
+
+        val extra_rewrites = descr |> map (#1 o snd) |> distinct op = |>
+          map_filter (Option.map snd o lookup_size thy) |> flat;
+        val extra_size = Option.map fst o lookup_size thy;
+
+        val (((size_names, size_fns), def_names), def_names') =
+          recTs1 |> map (fn T as Type (s, _) =>
+            let
+              val s' = Long_Name.base_name s ^ "_size";
+              val s'' = Sign.full_bname thy s';
+            in
+              (s'',
+               (list_comb (Const (s'', param_size_fTs @ [T] ---> HOLogic.natT),
+                  map snd param_size_fs),
+                (s' ^ "_def", s' ^ "_overloaded_def")))
+            end) |> split_list ||>> split_list ||>> split_list;
+        val overloaded_size_fns = map HOLogic.size_const recTs1;
 
-    val simpset2 =
-      put_simpset HOL_basic_ss ctxt
-        addsimps (rec_rewrites @ size_def_thms @ unfolded_size_eqs1);
-    val simpset3 =
-      put_simpset HOL_basic_ss ctxt
-        addsimps (rec_rewrites @ size_def_thms' @ unfolded_size_eqs2);
+        (* instantiation for primrec combinator *)
+        fun size_of_constr b size_ofp ((_, cargs), (_, cargs')) =
+          let
+            val Ts = map (Datatype_Aux.typ_of_dtyp descr) cargs;
+            val k = length (filter Datatype_Aux.is_rec_type cargs);
+            val (ts, _, _) = fold_rev (fn ((dt, dt'), T) => fn (us, i, j) =>
+              if Datatype_Aux.is_rec_type dt then (Bound i :: us, i + 1, j + 1)
+              else
+                (if b andalso is_poly thy dt' then
+                   case size_of_type (K NONE) extra_size size_ofp T of
+                     NONE => us | SOME sz => sz $ Bound j :: us
+                 else us, i, j + 1))
+                  (cargs ~~ cargs' ~~ Ts) ([], 0, k);
+            val t =
+              if null ts andalso (not b orelse not (exists (is_poly thy) cargs'))
+              then HOLogic.zero
+              else foldl1 plus (ts @ [HOLogic.Suc_zero])
+          in
+            fold_rev (fn T => fn t' => Abs ("x", T, t')) (Ts @ replicate k HOLogic.natT) t
+          end;
+
+        val fs = maps (fn (_, (name, _, constrs)) =>
+          map (size_of_constr true param_size) (constrs ~~ constrs_of thy name)) descr;
+        val fs' = maps (fn (n, (name, _, constrs)) =>
+          map (size_of_constr (l <= n) (K NONE)) (constrs ~~ constrs_of thy name)) descr;
+        val fTs = map fastype_of fs;
+
+        val (rec_combs1, rec_combs2) = chop l (map (fn (T, rec_name) =>
+          Const (rec_name, fTs @ [T] ---> HOLogic.natT))
+            (recTs ~~ rec_names));
+
+        fun define_overloaded (def_name, eq) lthy =
+          let
+            val (Free (c, _), rhs) = (Logic.dest_equals o Syntax.check_term lthy) eq;
+            val (thm, lthy') = lthy
+              |> Local_Theory.define ((Binding.name c, NoSyn), ((Binding.name def_name, []), rhs))
+              |-> (fn (t, (_, thm)) => Spec_Rules.add Spec_Rules.Equational ([t], [thm]) #> pair thm);
+            val ctxt_thy = Proof_Context.init_global (Proof_Context.theory_of lthy');
+            val thm' = singleton (Proof_Context.export lthy' ctxt_thy) thm;
+          in (thm', lthy') end;
 
-    fun prove_size_eqs p size_fns size_ofp simpset =
-      maps (fn (((_, (_, _, constrs)), size_const), T) =>
-        map (fn constr => Drule.export_without_context (Goal.prove_sorry ctxt [] []
-          (gen_mk_size_eq p (AList.lookup op = (new_type_names ~~ size_fns))
-             size_ofp size_const T constr)
-          (fn _ => simp_tac simpset 1))) constrs)
-        (descr' ~~ size_fns ~~ recTs1);
+        val ((size_def_thms, size_def_thms'), thy') =
+          thy
+          |> Sign.add_consts (map (fn (s, T) => (Binding.name (Long_Name.base_name s),
+              param_size_fTs @ [T] ---> HOLogic.natT, NoSyn))
+            (size_names ~~ recTs1))
+          |> Global_Theory.add_defs false
+            (map (Thm.no_attributes o apsnd (Logic.mk_equals o apsnd (app fs)))
+               (map Binding.name def_names ~~ (size_fns ~~ rec_combs1)))
+          ||> Class.instantiation (tycos, map dest_TFree paramTs, [HOLogic.class_size])
+          ||>> fold_map define_overloaded
+            (def_names' ~~ map Logic.mk_equals (overloaded_size_fns ~~ map (app fs') rec_combs1))
+          ||> Class.prove_instantiation_instance (K (Class.intro_classes_tac []))
+          ||> Local_Theory.exit_global;
+
+        val ctxt = Proof_Context.init_global thy';
 
-    val size_eqns = prove_size_eqs (is_poly thy') size_fns param_size simpset2 @
-      prove_size_eqs Datatype_Aux.is_rec_type overloaded_size_fns (K NONE) simpset3;
+        val simpset1 =
+          put_simpset HOL_basic_ss ctxt addsimps @{thm Nat.add_0} :: @{thm Nat.add_0_right} ::
+            size_def_thms @ size_def_thms' @ rec_rewrites @ extra_rewrites;
+        val xs = map (fn i => "x" ^ string_of_int i) (1 upto length recTs2);
+
+        fun mk_unfolded_size_eq tab size_ofp fs (p as (x, T), r) =
+          HOLogic.mk_eq (app fs r $ Free p,
+            the (size_of_type tab extra_size size_ofp T) $ Free p);
+
+        fun prove_unfolded_size_eqs size_ofp fs =
+          if null recTs2 then []
+          else Datatype_Aux.split_conj_thm (Goal.prove_sorry ctxt xs []
+            (HOLogic.mk_Trueprop (Datatype_Aux.mk_conj (replicate l @{term True} @
+               map (mk_unfolded_size_eq (AList.lookup op =
+                   (new_type_names ~~ map (app fs) rec_combs1)) size_ofp fs)
+                 (xs ~~ recTs2 ~~ rec_combs2))))
+            (fn _ => (Datatype_Aux.ind_tac induct xs THEN_ALL_NEW asm_simp_tac simpset1) 1));
+
+        val unfolded_size_eqs1 = prove_unfolded_size_eqs param_size fs;
+        val unfolded_size_eqs2 = prove_unfolded_size_eqs (K NONE) fs';
 
-    val ([(_, size_thms)], thy'') = thy'
-      |> Global_Theory.note_thmss ""
-        [((Binding.name "size",
-            [Simplifier.simp_add, Nitpick_Simps.add,
-             Thm.declaration_attribute (fn thm => Context.mapping (Code.add_default_eqn thm) I)]),
-          [(size_eqns, [])])];
+        (* characteristic equations for size functions *)
+        fun gen_mk_size_eq p size_of size_ofp size_const T (cname, cargs) =
+          let
+            val Ts = map (Datatype_Aux.typ_of_dtyp descr) cargs;
+            val tnames = Name.variant_list f_names (Datatype_Prop.make_tnames Ts);
+            val ts = map_filter (fn (sT as (s, T), dt) =>
+              Option.map (fn sz => sz $ Free sT)
+                (if p dt then size_of_type size_of extra_size size_ofp T
+                 else NONE)) (tnames ~~ Ts ~~ cargs)
+          in
+            HOLogic.mk_Trueprop (HOLogic.mk_eq
+              (size_const $ list_comb (Const (cname, Ts ---> T),
+                 map2 (curry Free) tnames Ts),
+               if null ts then HOLogic.zero
+               else foldl1 plus (ts @ [HOLogic.Suc_zero])))
+          end;
+
+        val simpset2 =
+          put_simpset HOL_basic_ss ctxt
+            addsimps (rec_rewrites @ size_def_thms @ unfolded_size_eqs1);
+        val simpset3 =
+          put_simpset HOL_basic_ss ctxt
+            addsimps (rec_rewrites @ size_def_thms' @ unfolded_size_eqs2);
 
-  in
-    Data.map (fold (Symtab.update_new o apsnd (rpair size_thms))
-      (new_type_names ~~ size_names)) thy''
+        fun prove_size_eqs p size_fns size_ofp simpset =
+          maps (fn (((_, (_, _, constrs)), size_const), T) =>
+            map (fn constr => Drule.export_without_context (Goal.prove_sorry ctxt [] []
+              (gen_mk_size_eq p (AList.lookup op = (new_type_names ~~ size_fns))
+                 size_ofp size_const T constr)
+              (fn _ => simp_tac simpset 1))) constrs)
+            (descr' ~~ size_fns ~~ recTs1);
+
+        val size_eqns = prove_size_eqs (is_poly thy') size_fns param_size simpset2 @
+          prove_size_eqs Datatype_Aux.is_rec_type overloaded_size_fns (K NONE) simpset3;
+
+        val ([(_, size_thms)], thy'') = thy'
+          |> Global_Theory.note_thmss ""
+            [((Binding.name "size",
+                [Simplifier.simp_add, Nitpick_Simps.add,
+                 Thm.declaration_attribute (fn thm =>
+                   Context.mapping (Code.add_default_eqn thm) I)]),
+              [(size_eqns, [])])];
+
+      in
+        Data.map (fold (Symtab.update_new o apsnd (rpair size_thms))
+          (new_type_names ~~ size_names)) thy''
+      end
   end;
 
 fun add_size_thms config (new_type_names as name :: _) thy =