refactored induction principle generation code, for reuse for nonuniform datatypes
authorblanchet
Fri, 16 Dec 2016 22:54:14 +0100
changeset 64576 ce8802dc3145
parent 64575 d44f0b714e13
child 64577 0288a566c966
child 64607 20f3dbfe4b24
refactored induction principle generation code, for reuse for nonuniform datatypes
src/HOL/Tools/BNF/bnf_fp_def_sugar.ML
--- a/src/HOL/Tools/BNF/bnf_fp_def_sugar.ML	Fri Dec 16 19:50:46 2016 +0100
+++ b/src/HOL/Tools/BNF/bnf_fp_def_sugar.ML	Fri Dec 16 22:54:14 2016 +0100
@@ -100,6 +100,7 @@
     'a list
   val mk_ctor: typ list -> term -> term
   val mk_dtor: typ list -> term -> term
+  val mk_bnf_sets: BNF_Def.bnf -> string * term list
   val liveness_of_fp_bnf: int -> BNF_Def.bnf -> bool list
   val nesting_bnfs: Proof.context -> typ list list list -> typ list -> BNF_Def.bnf list
 
@@ -162,6 +163,11 @@
       * (((term list list * term list list * term list list list list * term list list list list)
           * term list list list) * typ list) -> (string -> binding) -> 'b list -> typ list ->
     term list -> term -> local_theory -> (term * thm) * local_theory
+  val mk_induct_raw_prem: Proof.context -> typ list list -> (string * term list) list -> term ->
+    term -> typ list -> typ list ->
+    term list * ((term * (term * term)) list * (int * term)) list * term
+  val finish_induct_prem: Proof.context -> int -> term list ->
+    term list * ((term * (term * term)) list * (int * term)) list * term -> term
   val mk_induct_attrs: term list list -> Token.src list
   val mk_coinduct_attrs: typ list -> term list list -> term list list -> int list list ->
     Token.src list * Token.src list
@@ -568,6 +574,19 @@
 val mk_ctor = mk_ctor_or_dtor range_type;
 val mk_dtor = mk_ctor_or_dtor domain_type;
 
+fun mk_bnf_sets bnf =
+  let
+    val Type (T_name, Us) = T_of_bnf bnf;
+    val lives = lives_of_bnf bnf;
+    val sets = sets_of_bnf bnf;
+    fun mk_set U =
+      (case find_index (curry (op =) U) lives of
+        ~1 => Term.dummy
+      | i => nth sets i);
+  in
+    (T_name, map mk_set Us)
+  end;
+
 fun mk_xtor_co_recs thy fp fpTs Cs ts0 =
   let
     val nn = length fpTs;
@@ -614,10 +633,10 @@
 fun define_ctrs_dtrs_for_type fp_b_name fpT ctor dtor ctor_dtor dtor_ctor n ks abs ctr_bindings
     ctr_mixfixes ctr_Tss lthy =
   let
-    val ctr_absT = domain_type (fastype_of ctor);
+    val ctor_absT = domain_type (fastype_of ctor);
 
     val (((w, xss), u'), _) = lthy
-      |> yield_singleton (mk_Frees "w") ctr_absT
+      |> yield_singleton (mk_Frees "w") ctor_absT
       ||>> mk_Freess "x" ctr_Tss
       ||>> yield_singleton Variable.variant_fixes fp_b_name;
 
@@ -631,13 +650,13 @@
         val vars = Variable.add_free_names lthy goal [];
       in
         Goal.prove_sorry lthy vars [] goal (fn {context = ctxt, ...} =>
-          mk_ctor_iff_dtor_tac ctxt (map (SOME o Thm.ctyp_of lthy) [ctr_absT, fpT])
+          mk_ctor_iff_dtor_tac ctxt (map (SOME o Thm.ctyp_of lthy) [ctor_absT, fpT])
             (Thm.cterm_of lthy ctor) (Thm.cterm_of lthy dtor) ctor_dtor dtor_ctor)
         |> Thm.close_derivation
       end;
 
     val ctr_rhss =
-      map2 (fn k => fn xs => fold_rev Term.lambda xs (ctor $ mk_absumprod ctr_absT abs n k xs))
+      map2 (fn k => fn xs => fold_rev Term.lambda xs (ctor $ mk_absumprod ctor_absT abs n k xs))
         ks xss;
 
     val ((raw_ctrs, raw_ctr_defs), (lthy, lthy_old)) = lthy
@@ -1604,6 +1623,46 @@
          @{map 5} mk_preds_getterss_join cs cpss f_absTs abss cqgsss)))
   end;
 
+fun mk_induct_raw_prem_prems names_lthy Xss setss_fp_nesting (x as Free (s, Type (T_name, Ts0)))
+      (Type (_, Xs_Ts0)) =
+    (case AList.lookup (op =) setss_fp_nesting T_name of
+      NONE => []
+    | SOME raw_sets0 =>
+      let
+        val (Xs_Ts, (Ts, raw_sets)) =
+          filter (exists_subtype_in (flat Xss) o fst) (Xs_Ts0 ~~ (Ts0 ~~ raw_sets0))
+          |> split_list ||> split_list;
+        val sets = map (mk_set Ts0) raw_sets;
+        val (ys, names_lthy') = names_lthy |> mk_Frees s Ts;
+        val xysets = map (pair x) (ys ~~ sets);
+        val ppremss = map2 (mk_induct_raw_prem_prems names_lthy' Xss setss_fp_nesting) ys Xs_Ts;
+      in
+        flat (map2 (map o apfst o cons) xysets ppremss)
+      end)
+  | mk_induct_raw_prem_prems _ Xss _ (x as Free (_, Type _)) X =
+    [([], (find_index (fn Xs => member (op =) Xs X) Xss + 1, x))]
+  | mk_induct_raw_prem_prems _ _ _ _ _ = [];
+
+fun mk_induct_raw_prem names_lthy Xss setss_fp_nesting p ctr ctr_Ts ctrXs_Ts =
+  let
+    val (xs, names_lthy') = names_lthy |> mk_Frees "x" ctr_Ts;
+    val pprems =
+      flat (map2 (mk_induct_raw_prem_prems names_lthy' Xss setss_fp_nesting) xs ctrXs_Ts);
+  in (xs, pprems, HOLogic.mk_Trueprop (p $ Term.list_comb (ctr, xs))) end;
+
+fun close_induct_prem_prem nn ps xs t =
+  fold_rev Logic.all (map Free (drop (nn + length xs)
+    (rev (Term.add_frees t (map dest_Free xs @ map_filter (try dest_Free) ps))))) t;
+
+fun finish_induct_prem_prem lthy nn ps xs (xysets, (j, x)) =
+  close_induct_prem_prem nn ps xs (Logic.list_implies (map (fn (x', (y, set)) =>
+      mk_Trueprop_mem (y, set $ x')) xysets,
+    HOLogic.mk_Trueprop (enforce_type lthy domain_type (fastype_of x) (nth ps (j - 1)) $ x)));
+
+fun finish_induct_prem lthy nn ps (xs, raw_pprems, concl) =
+  fold_rev Logic.all xs (Logic.list_implies
+    (map (finish_induct_prem_prem lthy nn ps xs) raw_pprems, concl));
+
 fun postproc_co_induct ctxt nn prop prop_conj =
   Drule.zero_var_indexes
   #> `(conj_dests nn)
@@ -1675,72 +1734,25 @@
 
     val fp_b_names = map base_name_of_typ fpTs;
 
-    val ((((ps, ps'), xsss), us'), names_lthy) = lthy
-      |> mk_Frees' "P" (map mk_pred1T fpTs)
+    val (((ps, xsss), us'), names_lthy) = lthy
+      |> mk_Frees "P" (map mk_pred1T fpTs)
       ||>> mk_Freesss "x" ctr_Tsss
       ||>> Variable.variant_fixes fp_b_names;
 
     val us = map2 (curry Free) us' fpTs;
 
-    fun mk_sets bnf =
-      let
-        val Type (T_name, Us) = T_of_bnf bnf;
-        val lives = lives_of_bnf bnf;
-        val sets = sets_of_bnf bnf;
-        fun mk_set U =
-          (case find_index (curry (op =) U) lives of
-            ~1 => Term.dummy
-          | i => nth sets i);
-      in
-        (T_name, map mk_set Us)
-      end;
-
-    val setss_fp_nesting = map mk_sets fp_nesting_bnfs;
+    val setss_fp_nesting = map mk_bnf_sets fp_nesting_bnfs;
 
     val (induct_thms, induct_thm) =
       let
-        fun mk_raw_prem_prems _ (x as Free (_, Type _)) (X as TFree _) =
-            [([], (find_index (curry (op =) X) Xs + 1, x))]
-          | mk_raw_prem_prems names_lthy (x as Free (s, Type (T_name, Ts0))) (Type (_, Xs_Ts0)) =
-            (case AList.lookup (op =) setss_fp_nesting T_name of
-              NONE => []
-            | SOME raw_sets0 =>
-              let
-                val (Xs_Ts, (Ts, raw_sets)) =
-                  filter (exists_subtype_in Xs o fst) (Xs_Ts0 ~~ (Ts0 ~~ raw_sets0))
-                  |> split_list ||> split_list;
-                val sets = map (mk_set Ts0) raw_sets;
-                val (ys, names_lthy') = names_lthy |> mk_Frees s Ts;
-                val xysets = map (pair x) (ys ~~ sets);
-                val ppremss = map2 (mk_raw_prem_prems names_lthy') ys Xs_Ts;
-              in
-                flat (map2 (map o apfst o cons) xysets ppremss)
-              end)
-          | mk_raw_prem_prems _ _ _ = [];
-
-        fun close_prem_prem xs t =
-          fold_rev Logic.all (map Free (drop (nn + length xs)
-            (rev (Term.add_frees t (map dest_Free xs @ ps'))))) t;
-
-        fun mk_prem_prem xs (xysets, (j, x)) =
-          close_prem_prem xs (Logic.list_implies (map (fn (x', (y, set)) =>
-              mk_Trueprop_mem (y, set $ x')) xysets,
-            HOLogic.mk_Trueprop (nth ps (j - 1) $ x)));
-
-        fun mk_raw_prem phi ctr ctr_Ts ctrXs_Ts =
-          let
-            val (xs, names_lthy') = names_lthy |> mk_Frees "x" ctr_Ts;
-            val pprems = flat (map2 (mk_raw_prem_prems names_lthy') xs ctrXs_Ts);
-          in (xs, pprems, HOLogic.mk_Trueprop (phi $ Term.list_comb (ctr, xs))) end;
-
-        fun mk_prem (xs, raw_pprems, concl) =
-          fold_rev Logic.all xs (Logic.list_implies (map (mk_prem_prem xs) raw_pprems, concl));
-
-        val raw_premss = @{map 4} (@{map 3} o mk_raw_prem) ps ctrss ctr_Tsss ctrXs_Tsss;
-
+        val raw_premss = @{map 4} (@{map 3}
+            o mk_induct_raw_prem names_lthy (map single Xs) setss_fp_nesting)
+          ps ctrss ctr_Tsss ctrXs_Tsss;
+        val concl =
+          HOLogic.mk_Trueprop (Library.foldr1 HOLogic.mk_conj (map2 (curry (op $)) ps us));
         val goal =
-          Library.foldr (Logic.list_implies o apfst (map mk_prem)) (raw_premss,
-            HOLogic.mk_Trueprop (Library.foldr1 HOLogic.mk_conj (map2 (curry (op $)) ps us)));
+          Library.foldr (Logic.list_implies o apfst (map (finish_induct_prem lthy nn ps)))
+            (raw_premss, concl);
         val vars = Variable.add_free_names lthy goal [];
 
         val kksss = map (map (map (fst o snd) o #2)) raw_premss;