optimized simple non-recursive datatypes by reusing 'case' for 'rec' constant
authorblanchet
Mon, 03 Mar 2014 12:48:20 +0100
changeset 55862 b458558cbcc2
parent 55861 0a8200e31474
child 55863 fa3a1ec69a1b
optimized simple non-recursive datatypes by reusing 'case' for 'rec' constant
src/HOL/BNF_LFP.thy
src/HOL/Tools/BNF/bnf_fp_def_sugar.ML
src/HOL/Tools/BNF/bnf_fp_n2m_sugar.ML
--- a/src/HOL/BNF_LFP.thy	Mon Mar 03 12:48:20 2014 +0100
+++ b/src/HOL/BNF_LFP.thy	Mon Mar 03 12:48:20 2014 +0100
@@ -266,4 +266,7 @@
 
 datatype_new 'a F = F 'a
 
+primrec f where
+  "f (F x) = x"
+
 end
--- a/src/HOL/Tools/BNF/bnf_fp_def_sugar.ML	Mon Mar 03 12:48:20 2014 +0100
+++ b/src/HOL/Tools/BNF/bnf_fp_def_sugar.ML	Mon Mar 03 12:48:20 2014 +0100
@@ -57,35 +57,32 @@
   val transfer_gfp_sugar_thms: Proof.context -> gfp_sugar_thms -> gfp_sugar_thms
 
   val mk_co_iters_prelims: BNF_Util.fp_kind -> typ list list list -> typ list -> typ list ->
-    typ list -> typ list -> int list -> int list list -> term list list -> Proof.context ->
-    (term list list
-     * (typ list list * typ list list list list * term list list
-        * term list list list list) list option
-     * (string * term list * term list list
-        * ((term list list * term list list list) * typ list) list) option)
-    * Proof.context
+     typ list -> typ list -> int list -> int list list -> term list list -> Proof.context ->
+     (term list list * (typ list list * typ list list list list * term list list
+        * term list list list list) option
+      * (string * term list * term list list
+        * ((term list list * term list list list) * typ list)) option)
+     * Proof.context
   val repair_nullary_single_ctr: typ list list -> typ list list
   val mk_coiter_p_pred_types: typ list -> int list -> typ list list
   val mk_coiter_fun_arg_types: typ list list list -> typ list -> typ list -> typ list -> int list ->
     int list list -> term ->
     typ list list
     * (typ list list list list * typ list list list * typ list list list list * typ list)
-  val define_iters: string list ->
-    (typ list list * typ list list list list * term list list * term list list list list) list ->
-    (string -> binding) -> typ list -> typ list -> term list -> term list -> Proof.context ->
-    (term list * thm list) * Proof.context
-  val define_coiters: string list -> string * term list * term list list
-    * ((term list list * term list list list) * typ list) list ->
-    (string -> binding) -> typ list -> typ list -> term list -> term list -> Proof.context ->
+  val define_iter:
+    (typ list list * typ list list list list * term list list * term list list list list) option ->
+    (string -> binding) -> typ list -> typ list -> term list -> term -> Proof.context ->
     (term list * thm list) * Proof.context
+  val define_coiter: 'a * term list * term list list
+      * ((term list list * term list list list) * typ list) -> (string -> binding) -> 'b list ->
+    typ list -> term list -> term -> Proof.context -> (term list * thm list) * local_theory
   val derive_induct_iters_thms_for_types: BNF_Def.bnf list ->
-    (typ list list * typ list list list list * term list list * term list list list list) list ->
-    thm -> thm list list -> BNF_Def.bnf list -> BNF_Def.bnf list -> typ list -> typ list ->
-    typ list -> typ list list list -> thm list -> thm list -> thm list -> term list list ->
-    thm list list -> term list list -> thm list list -> local_theory -> lfp_sugar_thms
+     ('a * typ list list list list * term list list * 'b) option -> thm -> thm list list ->
+     BNF_Def.bnf list -> BNF_Def.bnf list -> typ list -> typ list -> typ list ->
+     typ list list list -> thm list -> thm list -> thm list -> term list list -> thm list list ->
+     term list list -> thm list list -> Proof.context -> lfp_sugar_thms
   val derive_coinduct_coiters_thms_for_types: BNF_Def.bnf list ->
-    string * term list * term list list * ((term list list * term list list list)
-      * typ list) list ->
+    string * term list * term list list * ((term list list * term list list list) * typ list) ->
     thm -> thm list -> thm list -> thm list list -> BNF_Def.bnf list -> typ list -> typ list ->
     typ list -> typ list list list -> int list list -> int list list -> int list -> thm list ->
     thm list -> (thm -> thm) -> thm list list -> Ctr_Sugar.ctr_sugar list -> term list list ->
@@ -381,7 +378,7 @@
       |> mk_Freessss "y" (map (map (map tl)) z_Tssss);
     val zssss = map2 (map2 (map2 cons)) zssss_hd zssss_tl;
   in
-    ([(h_Tss, z_Tssss, hss, zssss)], lthy)
+    ((h_Tss, z_Tssss, hss, zssss), lthy)
   end;
 
 (*avoid "'a itself" arguments in coiterators*)
@@ -449,12 +446,13 @@
 
     val corec_args = mk_args sssss hssss h_Tsss;
   in
-    ((z, cs, cpss, [(corec_args, corec_types)]), lthy)
+    ((z, cs, cpss, (corec_args, corec_types)), lthy)
   end;
 
 fun mk_co_iters_prelims fp ctr_Tsss fpTs Cs absTs repTs ns mss xtor_co_iterss0 lthy =
   let
     val thy = Proof_Context.theory_of lthy;
+    val nn = length fpTs;
 
     val (xtor_co_iter_fun_Tss, xtor_co_iterss) =
       map (mk_co_iters thy fp fpTs Cs #> `(binder_fun_types o fastype_of o hd))
@@ -463,8 +461,11 @@
 
     val ((iters_args_types, coiters_args_types), lthy') =
       if fp = Least_FP then
-        mk_iters_args_types ctr_Tsss Cs absTs repTs ns mss xtor_co_iter_fun_Tss lthy
-        |>> (rpair NONE o SOME)
+        if nn = 1 andalso forall (forall (forall (not o exists_subtype_in fpTs))) ctr_Tsss then
+          ((NONE, NONE), lthy)
+        else
+          mk_iters_args_types ctr_Tsss Cs absTs repTs ns mss xtor_co_iter_fun_Tss lthy
+          |>> (rpair NONE o SOME)
       else
         mk_coiters_args_types ctr_Tsss Cs absTs repTs ns mss xtor_co_iter_fun_Tss lthy
         |>> (pair NONE o SOME);
@@ -501,49 +502,44 @@
     ((csts', defs'), lthy')
   end;
 
-fun define_iters iterNs iter_args_typess' mk_binding fpTs Cs reps ctor_iters lthy =
-  let
-    val nn = length fpTs;
-    val fpT = domain_type (snd (strip_typeN nn (fastype_of (co_rec_of ctor_iters))));
+fun define_iter NONE _ _ _ _ _ lthy = (([], []), lthy)
+  | define_iter (SOME (_, _, fss, xssss)) mk_binding fpTs Cs reps ctor_iter lthy =
+    let
+      val nn = length fpTs;
+      val (ctor_iter_absTs, fpT) = strip_typeN nn (fastype_of ctor_iter)
+        |>> map domain_type ||> domain_type;
 
-    fun generate_iter pre (_, _, fss, xssss) ctor_iter =
-      let val ctor_iter_absTs = map domain_type (fst (strip_typeN nn (fastype_of ctor_iter))) in
-        (mk_binding pre,
+      val binding_spec =
+        (mk_binding recN,
          fold_rev (fold_rev Term.lambda) fss (Term.list_comb (ctor_iter,
            map4 (fn ctor_iter_absT => fn rep => fn fs => fn xsss =>
                mk_case_absumprod ctor_iter_absT rep fs
                  (map (HOLogic.mk_tuple o map HOLogic.mk_tuple) xsss) (map flat_rec_arg_args xsss))
-             ctor_iter_absTs reps fss xssss)))
-      end;
-  in
-    define_co_iters Least_FP fpT Cs (map3 generate_iter iterNs iter_args_typess' ctor_iters) lthy
-  end;
+             ctor_iter_absTs reps fss xssss)));
+    in
+      define_co_iters Least_FP fpT Cs [binding_spec] lthy
+    end;
 
-fun define_coiters coiterNs (_, cs, cpss, coiter_args_typess') mk_binding fpTs Cs abss dtor_coiters
-    lthy =
+fun define_coiter (_, cs, cpss, ((pfss, cqfsss), f_absTs)) mk_binding fpTs Cs abss dtor_coiter lthy =
   let
     val nn = length fpTs;
-    val fpT = range_type (snd (strip_typeN nn (fastype_of (co_rec_of dtor_coiters))));
+    val fpT = range_type (snd (strip_typeN nn (fastype_of dtor_coiter)));
 
-    fun generate_coiter pre ((pfss, cqfsss), f_absTs) dtor_coiter =
-      (mk_binding pre,
+    fun generate_coiter dtor_coiter =
+      (mk_binding corecN,
        fold_rev (fold_rev Term.lambda) pfss (Term.list_comb (dtor_coiter,
          map5 mk_preds_getterss_join cs cpss f_absTs abss cqfsss)));
   in
-    define_co_iters Greatest_FP fpT Cs
-      (map3 generate_coiter coiterNs coiter_args_typess' dtor_coiters) lthy
+    define_co_iters Greatest_FP fpT Cs [generate_coiter dtor_coiter] lthy
   end;
 
-fun derive_induct_iters_thms_for_types pre_bnfs [rec_args_types] ctor_induct ctor_iter_thmss
+fun derive_induct_iters_thms_for_types pre_bnfs rec_args_typess ctor_induct ctor_iter_thmss
     nesting_bnfs nested_bnfs fpTs Cs Xs ctrXs_Tsss fp_abs_inverses fp_type_definitions abs_inverses
     ctrss ctr_defss iterss iter_defss lthy =
   let
     val iterss' = transpose iterss;
     val iter_defss' = transpose iter_defss;
 
-    val [recs] = iterss';
-    val [rec_defs] = iter_defss';
-
     val ctr_Tsss = map (map (binder_types o fastype_of)) ctrss;
 
     val nn = length pre_bnfs;
@@ -689,14 +685,19 @@
         map2 (map2 prove) goalss tacss
       end;
 
-    val rec_thmss = mk_iter_thmss rec_args_types recs rec_defs (map co_rec_of ctor_iter_thmss);
+    val rec_thmss =
+      (case rec_args_typess of
+        SOME types =>
+        mk_iter_thmss types (the_single iterss') (the_single iter_defss')
+          (map co_rec_of ctor_iter_thmss)
+      | NONE => replicate nn []);
   in
     ((induct_thms, induct_thm, [induct_case_names_attr]),
      (rec_thmss, code_nitpicksimp_attrs @ simp_attrs))
   end;
 
 fun derive_coinduct_coiters_thms_for_types pre_bnfs (z, cs, cpss,
-      coiters_args_types as [((phss, cshsss), _)])
+      coiters_args_types as ((phss, cshsss), _))
     dtor_coinduct dtor_injects dtor_ctors dtor_coiter_thmss nesting_bnfs fpTs Cs Xs ctrXs_Tsss kss
     mss ns fp_abs_inverses abs_inverses mk_vimage2p ctr_defss (ctr_sugars : ctr_sugar list)
     coiterss coiter_defss export_args lthy =
@@ -825,7 +826,7 @@
     fun mk_maybe_not pos = not pos ? HOLogic.mk_not;
 
     val fcoiterss' as [hcorecs] =
-      map2 (fn (pfss, _) => map (lists_bmoc pfss)) (map fst coiters_args_types) coiterss';
+      map2 (fn (pfss, _) => map (lists_bmoc pfss)) [fst coiters_args_types] coiterss';
 
     val corec_thmss =
       let
@@ -1337,9 +1338,9 @@
         (wrap_ctrs
          #> derive_maps_sets_rels
          ##>>
-           (if fp = Least_FP then define_iters [recN] (the iters_args_types) mk_binding fpTs Cs reps
-           else define_coiters [corecN] (the coiters_args_types) mk_binding fpTs Cs abss)
-             [co_rec_of xtor_co_iters]
+           (if fp = Least_FP then define_iter iters_args_types mk_binding fpTs Cs reps
+           else define_coiter (the coiters_args_types) mk_binding fpTs Cs abss)
+             (co_rec_of xtor_co_iters)
          #> massage_res, lthy')
       end;
 
@@ -1357,12 +1358,16 @@
           (iterss, iter_defss)), lthy) =
       let
         val ((induct_thms, induct_thm, induct_attrs), (rec_thmss, iter_attrs)) =
-          derive_induct_iters_thms_for_types pre_bnfs (the iters_args_types) xtor_co_induct
+          derive_induct_iters_thms_for_types pre_bnfs iters_args_types xtor_co_induct
             xtor_co_iter_thmss nesting_bnfs nested_bnfs fpTs Cs Xs ctrXs_Tsss abs_inverses
             type_definitions abs_inverses ctrss ctr_defss iterss iter_defss lthy;
 
         val induct_type_attr = Attrib.internal o K o Induct.induct_type;
 
+        val (iterss', rec_thmss') =
+          if iterss = [[]] then ([map #casex ctr_sugars], map #case_thms ctr_sugars)
+          else (iterss, rec_thmss);
+
         val simp_thmss =
           map6 mk_simp_thms ctr_sugars rec_thmss mapss rel_injects rel_distincts setss;
 
@@ -1377,11 +1382,14 @@
           |> massage_multi_notes;
       in
         lthy
-        |> Spec_Rules.add Spec_Rules.Equational (map co_rec_of iterss, flat rec_thmss)
+        |> (if is_some iters_args_types then
+              Spec_Rules.add Spec_Rules.Equational (map co_rec_of iterss, flat rec_thmss)
+            else
+              I)
         |> Local_Theory.notes (common_notes @ notes) |> snd
         |> register_fp_sugars Xs Least_FP pre_bnfs absT_infos nested_bnfs nesting_bnfs fp_res
-          ctrXs_Tsss ctr_defss ctr_sugars iterss mapss [induct_thm] (map single induct_thms)
-          (map single rec_thmss) (replicate nn []) (replicate nn [])
+          ctrXs_Tsss ctr_defss ctr_sugars iterss' mapss [induct_thm] (map single induct_thms)
+          (map single rec_thmss') (replicate nn []) (replicate nn [])
       end;
 
     fun derive_note_coinduct_coiters_thms_for_types
--- a/src/HOL/Tools/BNF/bnf_fp_n2m_sugar.ML	Mon Mar 03 12:48:20 2014 +0100
+++ b/src/HOL/Tools/BNF/bnf_fp_n2m_sugar.ML	Mon Mar 03 12:48:20 2014 +0100
@@ -247,17 +247,15 @@
 
         val ((co_iterss, co_iter_defss), lthy) =
           fold_map2 (fn b =>
-            if fp = Least_FP then
-              define_iters [recN] (the iters_args_types) (mk_binding b) fpTs Cs reps
-            else
-              define_coiters [corecN] (the coiters_args_types) (mk_binding b) fpTs Cs abss)
-            fp_bs (map (single o co_rec_of) xtor_co_iterss) lthy
+            if fp = Least_FP then define_iter iters_args_types (mk_binding b) fpTs Cs reps
+            else define_coiter (the coiters_args_types) (mk_binding b) fpTs Cs abss)
+            fp_bs (map co_rec_of xtor_co_iterss) lthy
           |>> split_list;
 
         val ((common_co_inducts, co_inductss, co_rec_thmss, disc_corec_thmss, sel_corec_thmsss),
              fp_sugar_thms) =
           if fp = Least_FP then
-            derive_induct_iters_thms_for_types pre_bnfs (the iters_args_types) xtor_co_induct
+            derive_induct_iters_thms_for_types pre_bnfs iters_args_types xtor_co_induct
               xtor_co_iter_thmss nesting_bnfs nested_bnfs fpTs Cs Xs ctrXs_Tsss fp_abs_inverses
               fp_type_definitions abs_inverses ctrss ctr_defss co_iterss co_iter_defss lthy
             |> `(fn ((inducts, induct, _), (rec_thmss, _)) =>