allow predicator instead of map function in 'primrec'
authorblanchet
Wed, 17 Feb 2016 11:54:34 +0100
changeset 62326 3cf7a067599c
parent 62325 7e4d31eefe60
child 62327 112eefe85ff0
allow predicator instead of map function in 'primrec'
src/HOL/Nat.thy
src/HOL/Tools/BNF/bnf_def.ML
src/HOL/Tools/BNF/bnf_fp_n2m_sugar.ML
src/HOL/Tools/BNF/bnf_lfp_rec_sugar.ML
src/HOL/Tools/BNF/bnf_lfp_rec_sugar_more.ML
src/HOL/Tools/Ctr_Sugar/ctr_sugar.ML
--- a/src/HOL/Nat.thy	Tue Feb 16 22:28:19 2016 +0100
+++ b/src/HOL/Nat.thy	Wed Feb 17 11:54:34 2016 +0100
@@ -167,7 +167,7 @@
 setup \<open>
 let
   fun basic_lfp_sugars_of _ [@{typ nat}] _ _ ctxt =
-      ([], [0], [nat_basic_lfp_sugar], [], [], TrueI (*dummy*), [], false, ctxt)
+      ([], [0], [nat_basic_lfp_sugar], [], [], [], TrueI (*dummy*), [], false, ctxt)
     | basic_lfp_sugars_of bs arg_Ts callers callssss ctxt =
       BNF_LFP_Rec_Sugar.default_basic_lfp_sugars_of bs arg_Ts callers callssss ctxt;
 in
--- a/src/HOL/Tools/BNF/bnf_def.ML	Tue Feb 16 22:28:19 2016 +0100
+++ b/src/HOL/Tools/BNF/bnf_def.ML	Wed Feb 17 11:54:34 2016 +0100
@@ -115,6 +115,7 @@
   val wit_thmss_of_bnf: bnf -> thm list list
 
   val mk_map: int -> typ list -> typ list -> term -> term
+  val mk_pred: typ list -> term -> term
   val mk_rel: int -> typ list -> typ list -> term -> term
   val build_map: Proof.context -> typ list -> (typ * typ -> term) -> typ * typ -> term
   val build_rel: (string * (int * term)) list -> Proof.context -> typ list -> (typ * typ -> term) ->
@@ -712,6 +713,11 @@
     Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) t
   end;
 
+fun mk_pred Ts t =
+  let val Type (_, Ts0) = domain_type (body_fun_type (fastype_of t)) in
+    Term.subst_atomic_types (Ts0 ~~ Ts) t
+  end;
+
 fun mk_rel live Ts Us t =
   let val [Type (_, Ts0), Type (_, Us0)] = binder_types (snd (strip_typeN live (fastype_of t))) in
     Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) t
--- a/src/HOL/Tools/BNF/bnf_fp_n2m_sugar.ML	Tue Feb 16 22:28:19 2016 +0100
+++ b/src/HOL/Tools/BNF/bnf_fp_n2m_sugar.ML	Wed Feb 17 11:54:34 2016 +0100
@@ -10,6 +10,7 @@
   val unfold_lets_splits: term -> term
   val unfold_splits_lets: term -> term
   val dest_map: Proof.context -> string -> term -> term * term list
+  val dest_pred: Proof.context -> string -> term -> term * term list
 
   val mutualize_fp_sugars: (string -> bool) -> BNF_Util.fp_kind -> int list -> binding list ->
     typ list -> term list -> term list list list list -> BNF_FP_Def_Sugar.fp_sugar list ->
@@ -78,27 +79,29 @@
   | unfold_splits_lets (Abs (s, T, t)) = Abs (s, T, unfold_splits_lets t)
   | unfold_splits_lets t = unfold_lets_splits t;
 
-fun mk_map_pattern ctxt s =
+fun dest_either_map_or_pred map_or_pred_of_bnf ctxt T_name call =
   let
-    val bnf = the (bnf_of ctxt s);
-    val mapx = map_of_bnf bnf;
+    val bnf = the (bnf_of ctxt T_name);
+    val const0 = map_or_pred_of_bnf bnf;
     val live = live_of_bnf bnf;
-    val (f_Ts, _) = strip_typeN live (fastype_of mapx);
-    val fs = map_index (fn (i, T) => Var (("?f", i), T)) f_Ts;
+    val (f_Ts, _) = strip_typeN live (fastype_of const0);
+    val fs = map_index (fn (i, T) => Var (("f", i), T)) f_Ts;
+    val pat = betapplys (const0, fs);
+    val (_, tenv) = fo_match ctxt call pat;
   in
-    (mapx, betapplys (mapx, fs))
+    (const0, Vartab.fold_rev (fn (_, (_, f)) => cons f) tenv [])
   end;
 
-fun dest_map ctxt s call =
-  let
-    val (map0, pat) = mk_map_pattern ctxt s;
-    val (_, tenv) = fo_match ctxt call pat;
-  in
-    (map0, Vartab.fold_rev (fn (_, (_, f)) => cons f) tenv [])
-  end;
+val dest_map = dest_either_map_or_pred map_of_bnf;
+val dest_pred = dest_either_map_or_pred pred_of_bnf;
 
-fun dest_abs_or_applied_map _ _ (Abs (_, _, t)) = (Term.dummy, [t])
-  | dest_abs_or_applied_map ctxt s (t1 $ _) = dest_map ctxt s t1;
+fun dest_map_or_pred ctxt T_name call =
+  (case try (dest_map ctxt T_name) call of
+    SOME res => res
+  | NONE => dest_pred ctxt T_name call);
+
+fun dest_abs_or_applied_map_or_pred _ _ (Abs (_, _, t)) = (Term.dummy, [t])
+  | dest_abs_or_applied_map_or_pred ctxt s (t1 $ _) = dest_map_or_pred ctxt s t1;
 
 fun map_partition f xs =
   fold_rev (fn x => fn (ys, (good, bad)) =>
@@ -199,7 +202,7 @@
     and freeze_fpTs_call kk fpT calls (T as Type (s, _)) =
         (case map_partition (try (snd o dest_map no_defs_lthy s)) calls of
           ([], _) =>
-          (case map_partition (try (snd o dest_abs_or_applied_map no_defs_lthy s)) calls of
+          (case map_partition (try (snd o dest_abs_or_applied_map_or_pred no_defs_lthy s)) calls of
             ([], _) => freeze_fpTs_mutual_call kk fpT calls T
           | callsp => freeze_fpTs_map kk fpT callsp T)
         | callsp => freeze_fpTs_map kk fpT callsp T)
--- a/src/HOL/Tools/BNF/bnf_lfp_rec_sugar.ML	Tue Feb 16 22:28:19 2016 +0100
+++ b/src/HOL/Tools/BNF/bnf_lfp_rec_sugar.ML	Wed Feb 17 11:54:34 2016 +0100
@@ -28,6 +28,7 @@
     {recx: term,
      fp_nesting_map_ident0s: thm list,
      fp_nesting_map_comps: thm list,
+     fp_nesting_pred_maps: thm list,
      ctr_specs: rec_ctr_spec list}
 
   type basic_lfp_sugar =
@@ -44,16 +45,16 @@
      is_new_datatype: Proof.context -> string -> bool,
      basic_lfp_sugars_of: binding list -> typ list -> term list ->
        (term * term list list) list list -> local_theory ->
-       typ list * int list * basic_lfp_sugar list * thm list * thm list * thm * Token.src list
-       * bool * local_theory,
+       typ list * int list * basic_lfp_sugar list * thm list * thm list * thm list * thm
+       * Token.src list * bool * local_theory,
      rewrite_nested_rec_call: (Proof.context -> (term -> bool) -> (string -> int) -> typ list ->
        term -> term -> term -> term) option};
 
   val register_lfp_rec_extension: lfp_rec_extension -> theory -> theory
   val default_basic_lfp_sugars_of: binding list -> typ list -> term list ->
     (term * term list list) list list -> local_theory ->
-    typ list * int list * basic_lfp_sugar list * thm list * thm list * thm * Token.src list * bool
-    * local_theory
+    typ list * int list * basic_lfp_sugar list * thm list * thm list * thm list * thm
+    * Token.src list * bool * local_theory
   val rec_specs_of: binding list -> typ list -> typ list -> term list ->
     (term * term list list) list list -> local_theory ->
     (bool * rec_spec list * typ list * thm * thm list * Token.src list * typ list) * local_theory
@@ -117,6 +118,7 @@
   {recx: term,
    fp_nesting_map_ident0s: thm list,
    fp_nesting_map_comps: thm list,
+   fp_nesting_pred_maps: thm list,
    ctr_specs: rec_ctr_spec list};
 
 type basic_lfp_sugar =
@@ -133,8 +135,8 @@
    is_new_datatype: Proof.context -> string -> bool,
    basic_lfp_sugars_of: binding list -> typ list -> term list ->
      (term * term list list) list list -> local_theory ->
-     typ list * int list * basic_lfp_sugar list * thm list * thm list * thm * Token.src list * bool
-     * local_theory,
+     typ list * int list * basic_lfp_sugar list * thm list * thm list * thm list * thm
+     * Token.src list * bool * local_theory,
    rewrite_nested_rec_call: (Proof.context -> (term -> bool) -> (string -> int) -> typ list ->
      term -> term -> term -> term) option};
 
@@ -172,7 +174,7 @@
         {T = T, fp_res_index = 0, C = C, fun_arg_Tsss = fun_arg_Tsss, ctr_sugar = ctr_sugar,
          recx = casex, rec_thms = case_thms};
     in
-      ([], [0], [basic_lfp_sugar], [], [], TrueI, [], false, ctxt)
+      ([], [0], [basic_lfp_sugar], [], [], [], TrueI (*dummy*), [], false, ctxt)
     end
   | default_basic_lfp_sugars_of _ _ _ _ _ = error "Unsupported mutual recursion at this stage";
 
@@ -199,7 +201,7 @@
     val thy = Proof_Context.theory_of lthy0;
 
     val (missing_arg_Ts, perm0_kks, basic_lfp_sugars, fp_nesting_map_ident0s, fp_nesting_map_comps,
-         common_induct, induct_attrs, n2m, lthy) =
+         fp_nesting_pred_maps, common_induct, induct_attrs, n2m, lthy) =
       basic_lfp_sugars_of bs arg_Ts callers callssss0 lthy0;
 
     val perm_basic_lfp_sugars = sort (int_ord o apply2 #fp_res_index) basic_lfp_sugars;
@@ -259,6 +261,7 @@
         ({T, fp_res_index, ctr_sugar = {ctrs, ...}, recx, rec_thms, ...} : basic_lfp_sugar) =
       {recx = mk_co_rec thy Least_FP perm_Cs' (substAT T) recx,
        fp_nesting_map_ident0s = fp_nesting_map_ident0s, fp_nesting_map_comps = fp_nesting_map_comps,
+       fp_nesting_pred_maps = fp_nesting_pred_maps,
        ctr_specs = mk_ctr_specs fp_res_index ctr_offset ctrs rec_thms};
   in
     ((n2m, map2 mk_spec ctr_offsets basic_lfp_sugars, missing_arg_Ts, common_induct, inducts,
@@ -492,10 +495,12 @@
     |> (fn [] => NONE | callss => SOME (ctr, callss))
   end;
 
-fun mk_primrec_tac ctxt num_extra_args map_ident0s map_comps fun_defs recx =
+fun mk_primrec_tac ctxt num_extra_args fp_nesting_map_ident0s fp_nesting_map_comps
+    fp_nesting_pred_maps fun_defs recx =
   unfold_thms_tac ctxt fun_defs THEN
   HEADGOAL (rtac ctxt (funpow num_extra_args (fn thm => thm RS fun_cong) recx RS trans)) THEN
-  unfold_thms_tac ctxt (nested_simps ctxt @ map_ident0s @ map_comps) THEN
+  unfold_thms_tac ctxt (nested_simps ctxt @ fp_nesting_map_ident0s @ fp_nesting_map_comps @
+    fp_nesting_pred_maps) THEN
   HEADGOAL (rtac ctxt refl);
 
 fun prepare_primrec plugins nonexhaustives transfers fixes specs lthy0 =
@@ -541,8 +546,8 @@
 
     val defs = build_defs lthy nonexhaustives bs mxs funs_data rec_specs has_call;
 
-    fun prove def_thms ({ctr_specs, fp_nesting_map_ident0s, fp_nesting_map_comps, ...} : rec_spec)
-        (fun_data : eqn_data list) lthy' =
+    fun prove def_thms ({ctr_specs, fp_nesting_map_ident0s, fp_nesting_map_comps,
+        fp_nesting_pred_maps, ...} : rec_spec) (fun_data : eqn_data list) lthy' =
       let
         val js =
           find_indices (op = o apply2 (fn {fun_name, ctr, ...} => (fun_name, ctr)))
@@ -556,7 +561,7 @@
               Goal.prove_sorry lthy' [] [] user_eqn
                 (fn {context = ctxt, prems = _} =>
                   mk_primrec_tac ctxt num_extra_args fp_nesting_map_ident0s fp_nesting_map_comps
-                    def_thms rec_thm)
+                    fp_nesting_pred_maps def_thms rec_thm)
               |> Thm.close_derivation);
       in
         ((js, simps), lthy')
--- a/src/HOL/Tools/BNF/bnf_lfp_rec_sugar_more.ML	Tue Feb 16 22:28:19 2016 +0100
+++ b/src/HOL/Tools/BNF/bnf_lfp_rec_sugar_more.ML	Wed Feb 17 11:54:34 2016 +0100
@@ -36,7 +36,7 @@
    recx = recx, rec_thms = rec_thms};
 
 fun basic_lfp_sugars_of _ [@{typ nat}] _ _ lthy =
-    ([], [0], [nat_basic_lfp_sugar], [], [], TrueI (*dummy*), [], false, lthy)
+    ([], [0], [nat_basic_lfp_sugar], [], [], [], TrueI (*dummy*), [], false, lthy)
   | basic_lfp_sugars_of bs arg_Ts callers callssss0 lthy0 =
     let
       val ((missing_arg_Ts, perm0_kks,
@@ -63,10 +63,11 @@
 
       val fp_nesting_map_ident0s = map map_ident0_of_bnf fp_nesting_bnfs;
       val fp_nesting_map_comps = map map_comp_of_bnf fp_nesting_bnfs;
+      val fp_nesting_pred_maps = map pred_map_of_bnf fp_nesting_bnfs;
     in
       (missing_arg_Ts, perm0_kks, @{map 3} basic_lfp_sugar_of Cs fun_arg_Tssss fp_sugars,
-       fp_nesting_map_ident0s, fp_nesting_map_comps, common_induct, induct_attrs,
-       is_some lfp_sugar_thms, lthy)
+       fp_nesting_map_ident0s, fp_nesting_map_comps, fp_nesting_pred_maps, common_induct,
+       induct_attrs, is_some lfp_sugar_thms, lthy)
     end;
 
 exception NO_MAP of term;
@@ -108,7 +109,16 @@
           in
             Term.list_comb (map', fs')
           end
-        | NONE => raise NO_MAP t)
+        | NONE =>
+          (case try (dest_pred ctxt s) t of
+            SOME (pred0, fs) =>
+            let
+              val pred' = mk_pred Us pred0;
+              val fs' = map_flattened_map_args ctxt s (@{map 3} massage_map_or_map_arg Us Ts) fs;
+            in
+              Term.list_comb (pred', fs')
+            end
+          | NONE => raise NO_MAP t))
       | massage_map _ _ t = raise NO_MAP t
     and massage_map_or_map_arg U T t =
       if T = U then
--- a/src/HOL/Tools/Ctr_Sugar/ctr_sugar.ML	Tue Feb 16 22:28:19 2016 +0100
+++ b/src/HOL/Tools/Ctr_Sugar/ctr_sugar.ML	Wed Feb 17 11:54:34 2016 +0100
@@ -897,11 +897,11 @@
               val nontriv_disc_thmss =
                 map2 (fn b => if is_disc_binding_valid b then I else K []) disc_bindings disc_thmss;
 
-              fun is_discI_boring b =
+              fun is_discI_triv b =
                 (n = 1 andalso Binding.is_empty b) orelse Binding.eq_name (b, equal_binding);
 
               val nontriv_discI_thms =
-                flat (map2 (fn b => if is_discI_boring b then K [] else single) disc_bindings
+                flat (map2 (fn b => if is_discI_triv b then K [] else single) disc_bindings
                   discI_thms);
 
               val (distinct_disc_thms, (distinct_disc_thmsss', distinct_disc_thmsss)) =