generate coinductive witnesses for codatatypes
authortraytel
Mon, 03 Sep 2012 17:57:34 +0200
changeset 49104 6defdacd595a
parent 49103 3caaa80f53a4
child 49105 a426099dc343
generate coinductive witnesses for codatatypes
src/HOL/Codatatype/Tools/bnf_gfp.ML
src/HOL/Codatatype/Tools/bnf_gfp_tactics.ML
--- a/src/HOL/Codatatype/Tools/bnf_gfp.ML	Mon Sep 03 17:56:39 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_gfp.ML	Mon Sep 03 17:57:34 2012 +0200
@@ -22,6 +22,34 @@
 open BNF_GFP_Util
 open BNF_GFP_Tactics
 
+datatype wit_tree = Leaf of int | Node of (int * int * int list) * wit_tree list;
+
+fun mk_tree_args (I, T) (I', Ts) = (sort_distinct int_ord (I @ I'), T :: Ts);
+
+fun finish Iss m seen i (nwit, I) =
+  let
+    val treess = map (fn j =>
+        if j < m orelse member (op =) seen j then [([j], Leaf j)]
+        else
+          map_index (finish Iss m (insert (op =) j seen) j) (nth Iss (j - m))
+          |> flat
+          |> minimize_wits)
+      I;
+  in
+    map (fn (I, t) => (I, Node ((i - m, nwit, filter (fn i => i < m) I), t)))
+      (fold_rev (map_product mk_tree_args) treess [([], [])])
+    |> minimize_wits
+  end;
+
+fun tree_to_fld_wit vars _ _ (Leaf j) = ([j], nth vars j)
+  | tree_to_fld_wit vars flds witss (Node ((i, nwit, I), subtrees)) =
+     (I, nth flds i $ (Term.list_comb (snd (nth (nth witss i) nwit),
+       map (snd o tree_to_fld_wit vars flds witss) subtrees)));
+
+fun tree_to_coind_wits _ (Leaf j) = []
+  | tree_to_coind_wits lwitss (Node ((i, nwit, I), subtrees)) =
+     ((i, I), nth (nth lwitss i) nwit) :: maps (tree_to_coind_wits lwitss) subtrees;
+
 (*all bnfs have the same lives*)
 fun bnf_gfp bs Dss_insts bnfs lthy =
   let
@@ -2237,9 +2265,10 @@
         val XTs = mk_Ts passiveXs;
         val YTs = mk_Ts passiveYs;
 
-        val ((((((((((((((((((((fs, fs'), (fs_copy, fs'_copy)), (gs, gs')), us),
+        val (((((((((((((((((((((fs, fs'), (fs_copy, fs'_copy)), (gs, gs')), us),
           (Jys, Jys')), (Jys_copy, Jys'_copy)), set_induct_phiss), JRs), Jphis),
-          B1s), B2s), AXs), Xs), f1s), f2s), p1s), p2s), ps), (ys, ys')), names_lthy) = names_lthy
+          B1s), B2s), AXs), Xs), f1s), f2s), p1s), p2s), ps), (ys, ys')), (ys_copy, ys'_copy)),
+          names_lthy) = names_lthy
           |> mk_Frees' "f" fTs
           ||>> mk_Frees' "f" fTs
           ||>> mk_Frees' "g" gTs
@@ -2258,6 +2287,7 @@
           ||>> mk_Frees "p1" p1Ts
           ||>> mk_Frees "p2" p2Ts
           ||>> mk_Frees "p" pTs
+          ||>> mk_Frees' "y" passiveAs
           ||>> mk_Frees' "y" passiveAs;
 
         val map_FTFT's = map2 (fn Ds =>
@@ -2601,48 +2631,7 @@
         val tacss = map9 mk_tactics map_id_tacs map_comp_tacs map_cong_tacs set_nat_tacss bd_co_tacs
           bd_cinf_tacs set_bd_tacss in_bd_tacs map_wpull_tacs;
 
-        val fld_witss =
-          let
-            val witss = map2 (fn Ds => fn bnf => mk_wits_of_bnf
-              (replicate (nwits_of_bnf bnf) Ds)
-              (replicate (nwits_of_bnf bnf) (passiveAs @ Ts)) bnf) Dss bnfs;
-            fun close_wit (I, wit) = fold_rev Term.absfree (map (nth ys') I) wit;
-            fun wit_apply (arg_I, arg_wit) (fun_I, fun_wit) =
-              (union (op =) arg_I fun_I, fun_wit $ arg_wit);
-
-            fun gen_arg support i =
-              if i < m then [([i], nth ys i)]
-              else maps (mk_wit support (nth flds (i - m)) (i - m)) (nth support (i - m))
-            and mk_wit support fld i (I, wit) =
-              let val args = map (gen_arg (nth_map i (remove (op =) (I, wit)) support)) I;
-              in
-                (args, [([], wit)])
-                |-> fold (map_product wit_apply)
-                |> map (apsnd (fn t => fld $ t))
-                |> minimize_wits
-              end;
-          in
-            map3 (fn fld => fn i => map close_wit o minimize_wits o maps (mk_wit witss fld i))
-              flds (0 upto n - 1) witss
-          end;
-
-        val wit_tac = mk_wit_tac n unf_fld_thms (flat set_simp_thmss) (maps wit_thms_of_bnf bnfs);
-
-        val (Jbnfs, lthy) =
-          fold_map6 (fn tacs => fn b => fn map => fn sets => fn T => fn wits =>
-            bnf_def Dont_Inline user_policy I tacs wit_tac (SOME deads)
-              ((((b, fold_rev Term.absfree fs' map), sets), absdummy T bd), wits))
-          tacss bs fs_maps setss_by_bnf Ts fld_witss lthy;
-
-        val fold_maps = Local_Defs.fold lthy (map (fn bnf =>
-          mk_unabs_def m (map_def_of_bnf bnf RS @{thm meta_eq_to_obj_eq})) Jbnfs);
-
-        val fold_sets = Local_Defs.fold lthy (maps (fn bnf =>
-         map (fn thm => thm RS @{thm meta_eq_to_obj_eq}) (set_defs_of_bnf bnf)) Jbnfs);
-
-        val timer = time (timer "registered new codatatypes as BNFs");
-
-        val (set_incl_thmss, set_set_incl_thmsss, set_induct_thms) =
+        val (hset_unf_incl_thmss, hset_hset_unf_incl_thmsss, hset_induct_thms) =
           let
             fun tinst_of unf =
               map (SOME o certify lthy) (unf :: remove (op =) unf unfs);
@@ -2651,19 +2640,19 @@
               (map Logic.varifyT_global (deads @ allAs) ~~ (deads @ passiveAs @ Ts));
             val set_incl_thmss =
               map2 (fn unf => map (singleton (Proof_Context.export names_lthy lthy) o
-                fold_sets o Drule.instantiate' [] (tinst_of' unf) o
+                Drule.instantiate' [] (tinst_of' unf) o
                 Thm.instantiate (Tinst, []) o Drule.zero_var_indexes))
               unfs set_incl_hset_thmss;
 
             val tinst = interleave (map (SOME o certify lthy) unfs) (replicate n NONE)
             val set_minimal_thms =
-              map (fold_sets o Drule.instantiate' [] tinst o Thm.instantiate (Tinst, []) o
+              map (Drule.instantiate' [] tinst o Thm.instantiate (Tinst, []) o
                 Drule.zero_var_indexes)
               hset_minimal_thms;
 
             val set_set_incl_thmsss =
               map2 (fn unf => map (map (singleton (Proof_Context.export names_lthy lthy) o
-                fold_sets o Drule.instantiate' [] (NONE :: tinst_of' unf) o
+                Drule.instantiate' [] (NONE :: tinst_of' unf) o
                 Thm.instantiate (Tinst, []) o Drule.zero_var_indexes)))
               unfs set_hset_incl_hset_thmsss;
 
@@ -2682,7 +2671,7 @@
               map6 (fn set_minimal => fn set_set_inclss => fn jsets => fn y => fn y' => fn phis =>
                 ((set_minimal
                   |> Drule.instantiate' [] (mk_induct_tinst phis jsets y y')
-                  |> fold_sets |> Local_Defs.unfold lthy incls) OF
+                  |> Local_Defs.unfold lthy incls) OF
                   (replicate n ballI @
                     maps (map (fn thm => thm RS @{thm subset_CollectI})) set_set_inclss))
                 |> singleton (Proof_Context.export names_lthy lthy)
@@ -2692,6 +2681,158 @@
             (set_incl_thmss, set_set_incl_thmsss, set_induct_thms)
           end;
 
+        fun close_wit I wit = (I, fold_rev Term.absfree (map (nth ys') I) wit);
+
+        val all_unitTs = replicate live HOLogic.unitT;
+        val unitTs = replicate n HOLogic.unitT;
+        val unit_funs = replicate n (Term.absdummy HOLogic.unitT HOLogic.unit);
+        fun mk_map_args I =
+          map (fn i =>
+            if member (op =) I i then Term.absdummy HOLogic.unitT (nth ys i)
+            else mk_undefined (HOLogic.unitT --> nth passiveAs i))
+          (0 upto (m - 1));
+
+        fun mk_nat_wit Ds bnf (I, wit) () =
+          let
+            val passiveI = filter (fn i => i < m) I;
+            val map_args = mk_map_args passiveI;
+          in
+            Term.absdummy HOLogic.unitT (Term.list_comb
+              (mk_map_of_bnf Ds all_unitTs (passiveAs @ unitTs) bnf, map_args @ unit_funs) $ wit)
+          end;
+
+        fun mk_dummy_wit Ds bnf I =
+          let
+            val map_args = mk_map_args I;
+          in
+            Term.absdummy HOLogic.unitT (Term.list_comb
+              (mk_map_of_bnf Ds all_unitTs (passiveAs @ unitTs) bnf, map_args @ unit_funs) $
+              mk_undefined (mk_T_of_bnf Ds all_unitTs bnf))
+          end;
+
+        val nat_witss =
+          map3 (fn i => fn Ds => fn bnf => mk_wits_of_bnf (replicate (nwits_of_bnf bnf) Ds)
+            (replicate (nwits_of_bnf bnf) (replicate live HOLogic.unitT)) bnf
+            |> map (fn (I, wit) =>
+              (I, Lazy.lazy (mk_nat_wit Ds bnf (I, Term.list_comb (wit, map (K HOLogic.unit) I))))))
+          ks Dss bnfs;
+
+        val nat_wit_thmss = map2 (curry op ~~) nat_witss (map wit_thmss_of_bnf bnfs)
+
+        val Iss = map (map fst) nat_witss;
+
+        fun filter_wits (I, wit) =
+          let val J = filter (fn i => i < m) I;
+          in (J, (length J < length I, wit)) end;
+
+        val wit_treess = map_index (fn (i, Is) =>
+          map_index (finish Iss m [i+m] (i+m)) Is) Iss
+          |> map (minimize_wits o map filter_wits o minimize_wits o flat);
+
+        val coind_wit_argsss =
+          map (map (tree_to_coind_wits nat_wit_thmss o snd o snd) o filter (fst o snd)) wit_treess;
+
+        val nonredundant_coind_wit_argsss =
+          fold (fn i => fn argsss =>
+            nth_map (i - 1) (filter_out (fn xs =>
+              exists (fn ys =>
+                let
+                  val xs' = (map (fst o fst) xs, snd (fst (hd xs)));
+                  val ys' = (map (fst o fst) ys, snd (fst (hd ys)));
+                in
+                  eq_pair (subset (op =)) (eq_set (op =)) (xs', ys') andalso not (fst xs' = fst ys')
+                end)
+              (flat argsss)))
+            argsss)
+          ks coind_wit_argsss;
+
+        fun prepare_args args =
+          let
+            val I = snd (fst (hd args));
+            val (dummys, args') =
+              map_split (fn i =>
+                (case find_first (fn arg => fst (fst arg) = i - 1) args of
+                  SOME (_, ((_, wit), thms)) => (NONE, (Lazy.force wit, thms))
+                | NONE =>
+                  (SOME (i - 1), (mk_dummy_wit (nth Dss (i - 1)) (nth bnfs (i - 1)) I, []))))
+              ks;
+          in
+            ((I, dummys), apsnd flat (split_list args'))
+          end;
+
+        fun mk_coind_wits ((I, dummys), (args, thms)) =
+          ((I, dummys), (map (fn i => mk_coiter Ts args i $ HOLogic.unit) ks, thms));
+
+        val coind_witss =
+          maps (map (mk_coind_wits o prepare_args)) nonredundant_coind_wit_argsss;
+
+        val _ = (warning o PolyML.makestring) (map length coind_wit_argsss)
+        val _ = (warning o PolyML.makestring) (map length nonredundant_coind_wit_argsss)
+
+        fun mk_coind_wit_thms ((I, dummys), (wits, wit_thms)) =
+          let
+            fun mk_goal sets y y_copy y'_copy j =
+              let
+                fun mk_conjunct set z dummy wit =
+                  mk_Ball (set $ z) (Term.absfree y'_copy
+                    (if dummy = NONE orelse member (op =) I (j - 1) then
+                      HOLogic.mk_imp (HOLogic.mk_eq (z, wit),
+                        if member (op =) I (j - 1) then HOLogic.mk_eq (y_copy, y)
+                        else @{term False})
+                    else @{term True}));
+              in
+                fold_rev Logic.all (map (nth ys) I @ Jzs) (HOLogic.mk_Trueprop
+                  (Library.foldr1 HOLogic.mk_conj (map4 mk_conjunct sets Jzs dummys wits)))
+              end;
+            val goals = map5 mk_goal setss_by_range ys ys_copy ys'_copy ls;
+          in
+            map2 (fn goal => fn induct =>
+              Skip_Proof.prove lthy [] [] goal
+               (mk_coind_wit_tac induct coiter_thms (flat set_natural'ss) wit_thms))
+            goals hset_induct_thms
+            |> map split_conj_thm
+            |> transpose
+            |> map (map_filter (try (fn thm => thm RS bspec RS mp)))
+            |> curry op ~~ (map_index Library.I (map (close_wit I) wits))
+            |> filter (fn (_, thms) => length thms = m)
+          end;
+
+        val coind_wit_thms = maps mk_coind_wit_thms coind_witss;
+
+        val witss = map2 (fn Ds => fn bnf => mk_wits_of_bnf
+          (replicate (nwits_of_bnf bnf) Ds)
+          (replicate (nwits_of_bnf bnf) (passiveAs @ Ts)) bnf) Dss bnfs;
+
+        val fld_witss =
+          map (map (uncurry close_wit o tree_to_fld_wit ys flds witss o snd o snd) o
+            filter_out (fst o snd)) wit_treess;
+
+        val all_witss =
+          fold (fn ((i, wit), thms) => fn witss =>
+            nth_map i (fn (thms', wits) => (thms @ thms', wit :: wits)) witss)
+          coind_wit_thms (map (pair []) fld_witss)
+          |> map (apsnd (map snd o minimize_wits));
+
+        val wit_tac = mk_wit_tac n unf_fld_thms (flat set_simp_thmss) (maps wit_thms_of_bnf bnfs);
+
+        val (Jbnfs, lthy) =
+          fold_map6 (fn tacs => fn b => fn map => fn sets => fn T => fn (thms, wits) =>
+            bnf_def Dont_Inline user_policy I tacs (wit_tac thms) (SOME deads)
+              ((((b, fold_rev Term.absfree fs' map), sets), absdummy T bd), wits))
+          tacss bs fs_maps setss_by_bnf Ts all_witss lthy;
+
+        val fold_maps = Local_Defs.fold lthy (map (fn bnf =>
+          mk_unabs_def m (map_def_of_bnf bnf RS @{thm meta_eq_to_obj_eq})) Jbnfs);
+
+        val fold_sets = Local_Defs.fold lthy (maps (fn bnf =>
+         map (fn thm => thm RS @{thm meta_eq_to_obj_eq}) (set_defs_of_bnf bnf)) Jbnfs);
+
+        val timer = time (timer "registered new codatatypes as BNFs");
+
+        val set_incl_thmss = map (map fold_sets) hset_unf_incl_thmss;
+        val set_set_incl_thmsss = map (map (map fold_sets)) hset_hset_unf_incl_thmsss;
+        val set_induct_thms = map fold_sets hset_induct_thms;
+
         val rels = map2 (fn Ds => mk_rel_of_bnf Ds (passiveAs @ Ts) (passiveBs @ Ts')) Dss bnfs;
         val Jrels = map (mk_rel_of_bnf deads passiveAs passiveBs) Jbnfs;
         val preds = map2 (fn Ds => mk_pred_of_bnf Ds (passiveAs @ Ts) (passiveBs @ Ts')) Dss bnfs;
--- a/src/HOL/Codatatype/Tools/bnf_gfp_tactics.ML	Mon Sep 03 17:56:39 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_gfp_tactics.ML	Mon Sep 03 17:57:34 2012 +0200
@@ -28,6 +28,8 @@
   val mk_coalg_set_tac: thm -> tactic
   val mk_coalg_thePull_tac: int -> thm -> thm list -> thm list list -> (int -> tactic) list ->
     {prems: 'a, context: Proof.context} -> tactic
+  val mk_coind_wit_tac: thm -> thm list -> thm list -> thm list ->
+    {prems: 'a, context: Proof.context} -> tactic
   val mk_coiter_unique_mor_tac: thm list -> thm -> thm -> thm list -> tactic
   val mk_col_bd_tac: int -> int -> cterm option list -> thm list -> thm list -> thm -> thm ->
     thm list list -> tactic
@@ -114,7 +116,7 @@
   val mk_unf_o_fld_tac: thm -> thm -> thm -> thm -> thm list ->
     {prems: 'a, context: Proof.context} -> tactic
   val mk_unique_mor_tac: thm list -> thm -> tactic
-  val mk_wit_tac: int -> thm list -> thm list -> thm list ->
+  val mk_wit_tac: int -> thm list -> thm list -> thm list -> thm list ->
     {prems: 'a, context: Proof.context} -> tactic
   val mk_wpull_tac: int -> thm -> thm -> thm -> thm -> thm -> thm list -> thm list -> tactic
 end;
@@ -1463,7 +1465,8 @@
         rtac @{thm prod_caseI}, etac conjI, etac conjI, atac])
     (pick_cols ~~ hset_defs)] 1;
 
-fun mk_wit_tac n unf_flds set_simp wit {context = ctxt, prems = _} =
+fun mk_wit_tac n unf_flds set_simp wit coind_wits {context = ctxt, prems = _} =
+  ALLGOALS (TRY o (eresolve_tac coind_wits THEN' rtac refl)) THEN
   REPEAT_DETERM (atac 1 ORELSE
     EVERY' [dtac @{thm set_rev_mp}, rtac equalityD1, resolve_tac set_simp,
     K (Local_Defs.unfold_tac ctxt unf_flds),
@@ -1476,6 +1479,13 @@
           EVERY' [hyp_subst_tac, dtac @{thm set_rev_mp}, rtac equalityD1, resolve_tac set_simp,
             K (Local_Defs.unfold_tac ctxt unf_flds), REPEAT_DETERM_N n o etac UnE]))))] 1);
 
+fun mk_coind_wit_tac induct coiters set_nats wits {context = ctxt, prems = _} =
+  rtac induct 1 THEN ALLGOALS (TRY o rtac impI THEN' TRY o hyp_subst_tac) THEN
+  Local_Defs.unfold_tac ctxt (coiters @ set_nats @ @{thms image_id id_apply}) THEN
+  ALLGOALS (REPEAT_DETERM o etac imageE THEN' TRY o hyp_subst_tac) THEN
+  ALLGOALS (TRY o
+    FIRST' [rtac TrueI, rtac refl, etac (refl RSN (2, mp)), dresolve_tac wits THEN' etac FalseE])
+
 fun mk_rel_unfold_tac in_Jrels i in_rel map_comp map_cong map_simp set_simps unf_inject unf_fld
   set_naturals set_incls set_set_inclss =
   let