src/HOL/Codatatype/Tools/bnf_gfp.ML
changeset 49104 6defdacd595a
parent 49074 d8af889dcbe3
child 49105 a426099dc343
     1.1 --- a/src/HOL/Codatatype/Tools/bnf_gfp.ML	Mon Sep 03 17:56:39 2012 +0200
     1.2 +++ b/src/HOL/Codatatype/Tools/bnf_gfp.ML	Mon Sep 03 17:57:34 2012 +0200
     1.3 @@ -22,6 +22,34 @@
     1.4  open BNF_GFP_Util
     1.5  open BNF_GFP_Tactics
     1.6  
     1.7 +datatype wit_tree = Leaf of int | Node of (int * int * int list) * wit_tree list;
     1.8 +
     1.9 +fun mk_tree_args (I, T) (I', Ts) = (sort_distinct int_ord (I @ I'), T :: Ts);
    1.10 +
    1.11 +fun finish Iss m seen i (nwit, I) =
    1.12 +  let
    1.13 +    val treess = map (fn j =>
    1.14 +        if j < m orelse member (op =) seen j then [([j], Leaf j)]
    1.15 +        else
    1.16 +          map_index (finish Iss m (insert (op =) j seen) j) (nth Iss (j - m))
    1.17 +          |> flat
    1.18 +          |> minimize_wits)
    1.19 +      I;
    1.20 +  in
    1.21 +    map (fn (I, t) => (I, Node ((i - m, nwit, filter (fn i => i < m) I), t)))
    1.22 +      (fold_rev (map_product mk_tree_args) treess [([], [])])
    1.23 +    |> minimize_wits
    1.24 +  end;
    1.25 +
    1.26 +fun tree_to_fld_wit vars _ _ (Leaf j) = ([j], nth vars j)
    1.27 +  | tree_to_fld_wit vars flds witss (Node ((i, nwit, I), subtrees)) =
    1.28 +     (I, nth flds i $ (Term.list_comb (snd (nth (nth witss i) nwit),
    1.29 +       map (snd o tree_to_fld_wit vars flds witss) subtrees)));
    1.30 +
    1.31 +fun tree_to_coind_wits _ (Leaf j) = []
    1.32 +  | tree_to_coind_wits lwitss (Node ((i, nwit, I), subtrees)) =
    1.33 +     ((i, I), nth (nth lwitss i) nwit) :: maps (tree_to_coind_wits lwitss) subtrees;
    1.34 +
    1.35  (*all bnfs have the same lives*)
    1.36  fun bnf_gfp bs Dss_insts bnfs lthy =
    1.37    let
    1.38 @@ -2237,9 +2265,10 @@
    1.39          val XTs = mk_Ts passiveXs;
    1.40          val YTs = mk_Ts passiveYs;
    1.41  
    1.42 -        val ((((((((((((((((((((fs, fs'), (fs_copy, fs'_copy)), (gs, gs')), us),
    1.43 +        val (((((((((((((((((((((fs, fs'), (fs_copy, fs'_copy)), (gs, gs')), us),
    1.44            (Jys, Jys')), (Jys_copy, Jys'_copy)), set_induct_phiss), JRs), Jphis),
    1.45 -          B1s), B2s), AXs), Xs), f1s), f2s), p1s), p2s), ps), (ys, ys')), names_lthy) = names_lthy
    1.46 +          B1s), B2s), AXs), Xs), f1s), f2s), p1s), p2s), ps), (ys, ys')), (ys_copy, ys'_copy)),
    1.47 +          names_lthy) = names_lthy
    1.48            |> mk_Frees' "f" fTs
    1.49            ||>> mk_Frees' "f" fTs
    1.50            ||>> mk_Frees' "g" gTs
    1.51 @@ -2258,6 +2287,7 @@
    1.52            ||>> mk_Frees "p1" p1Ts
    1.53            ||>> mk_Frees "p2" p2Ts
    1.54            ||>> mk_Frees "p" pTs
    1.55 +          ||>> mk_Frees' "y" passiveAs
    1.56            ||>> mk_Frees' "y" passiveAs;
    1.57  
    1.58          val map_FTFT's = map2 (fn Ds =>
    1.59 @@ -2601,48 +2631,7 @@
    1.60          val tacss = map9 mk_tactics map_id_tacs map_comp_tacs map_cong_tacs set_nat_tacss bd_co_tacs
    1.61            bd_cinf_tacs set_bd_tacss in_bd_tacs map_wpull_tacs;
    1.62  
    1.63 -        val fld_witss =
    1.64 -          let
    1.65 -            val witss = map2 (fn Ds => fn bnf => mk_wits_of_bnf
    1.66 -              (replicate (nwits_of_bnf bnf) Ds)
    1.67 -              (replicate (nwits_of_bnf bnf) (passiveAs @ Ts)) bnf) Dss bnfs;
    1.68 -            fun close_wit (I, wit) = fold_rev Term.absfree (map (nth ys') I) wit;
    1.69 -            fun wit_apply (arg_I, arg_wit) (fun_I, fun_wit) =
    1.70 -              (union (op =) arg_I fun_I, fun_wit $ arg_wit);
    1.71 -
    1.72 -            fun gen_arg support i =
    1.73 -              if i < m then [([i], nth ys i)]
    1.74 -              else maps (mk_wit support (nth flds (i - m)) (i - m)) (nth support (i - m))
    1.75 -            and mk_wit support fld i (I, wit) =
    1.76 -              let val args = map (gen_arg (nth_map i (remove (op =) (I, wit)) support)) I;
    1.77 -              in
    1.78 -                (args, [([], wit)])
    1.79 -                |-> fold (map_product wit_apply)
    1.80 -                |> map (apsnd (fn t => fld $ t))
    1.81 -                |> minimize_wits
    1.82 -              end;
    1.83 -          in
    1.84 -            map3 (fn fld => fn i => map close_wit o minimize_wits o maps (mk_wit witss fld i))
    1.85 -              flds (0 upto n - 1) witss
    1.86 -          end;
    1.87 -
    1.88 -        val wit_tac = mk_wit_tac n unf_fld_thms (flat set_simp_thmss) (maps wit_thms_of_bnf bnfs);
    1.89 -
    1.90 -        val (Jbnfs, lthy) =
    1.91 -          fold_map6 (fn tacs => fn b => fn map => fn sets => fn T => fn wits =>
    1.92 -            bnf_def Dont_Inline user_policy I tacs wit_tac (SOME deads)
    1.93 -              ((((b, fold_rev Term.absfree fs' map), sets), absdummy T bd), wits))
    1.94 -          tacss bs fs_maps setss_by_bnf Ts fld_witss lthy;
    1.95 -
    1.96 -        val fold_maps = Local_Defs.fold lthy (map (fn bnf =>
    1.97 -          mk_unabs_def m (map_def_of_bnf bnf RS @{thm meta_eq_to_obj_eq})) Jbnfs);
    1.98 -
    1.99 -        val fold_sets = Local_Defs.fold lthy (maps (fn bnf =>
   1.100 -         map (fn thm => thm RS @{thm meta_eq_to_obj_eq}) (set_defs_of_bnf bnf)) Jbnfs);
   1.101 -
   1.102 -        val timer = time (timer "registered new codatatypes as BNFs");
   1.103 -
   1.104 -        val (set_incl_thmss, set_set_incl_thmsss, set_induct_thms) =
   1.105 +        val (hset_unf_incl_thmss, hset_hset_unf_incl_thmsss, hset_induct_thms) =
   1.106            let
   1.107              fun tinst_of unf =
   1.108                map (SOME o certify lthy) (unf :: remove (op =) unf unfs);
   1.109 @@ -2651,19 +2640,19 @@
   1.110                (map Logic.varifyT_global (deads @ allAs) ~~ (deads @ passiveAs @ Ts));
   1.111              val set_incl_thmss =
   1.112                map2 (fn unf => map (singleton (Proof_Context.export names_lthy lthy) o
   1.113 -                fold_sets o Drule.instantiate' [] (tinst_of' unf) o
   1.114 +                Drule.instantiate' [] (tinst_of' unf) o
   1.115                  Thm.instantiate (Tinst, []) o Drule.zero_var_indexes))
   1.116                unfs set_incl_hset_thmss;
   1.117  
   1.118              val tinst = interleave (map (SOME o certify lthy) unfs) (replicate n NONE)
   1.119              val set_minimal_thms =
   1.120 -              map (fold_sets o Drule.instantiate' [] tinst o Thm.instantiate (Tinst, []) o
   1.121 +              map (Drule.instantiate' [] tinst o Thm.instantiate (Tinst, []) o
   1.122                  Drule.zero_var_indexes)
   1.123                hset_minimal_thms;
   1.124  
   1.125              val set_set_incl_thmsss =
   1.126                map2 (fn unf => map (map (singleton (Proof_Context.export names_lthy lthy) o
   1.127 -                fold_sets o Drule.instantiate' [] (NONE :: tinst_of' unf) o
   1.128 +                Drule.instantiate' [] (NONE :: tinst_of' unf) o
   1.129                  Thm.instantiate (Tinst, []) o Drule.zero_var_indexes)))
   1.130                unfs set_hset_incl_hset_thmsss;
   1.131  
   1.132 @@ -2682,7 +2671,7 @@
   1.133                map6 (fn set_minimal => fn set_set_inclss => fn jsets => fn y => fn y' => fn phis =>
   1.134                  ((set_minimal
   1.135                    |> Drule.instantiate' [] (mk_induct_tinst phis jsets y y')
   1.136 -                  |> fold_sets |> Local_Defs.unfold lthy incls) OF
   1.137 +                  |> Local_Defs.unfold lthy incls) OF
   1.138                    (replicate n ballI @
   1.139                      maps (map (fn thm => thm RS @{thm subset_CollectI})) set_set_inclss))
   1.140                  |> singleton (Proof_Context.export names_lthy lthy)
   1.141 @@ -2692,6 +2681,158 @@
   1.142              (set_incl_thmss, set_set_incl_thmsss, set_induct_thms)
   1.143            end;
   1.144  
   1.145 +        fun close_wit I wit = (I, fold_rev Term.absfree (map (nth ys') I) wit);
   1.146 +
   1.147 +        val all_unitTs = replicate live HOLogic.unitT;
   1.148 +        val unitTs = replicate n HOLogic.unitT;
   1.149 +        val unit_funs = replicate n (Term.absdummy HOLogic.unitT HOLogic.unit);
   1.150 +        fun mk_map_args I =
   1.151 +          map (fn i =>
   1.152 +            if member (op =) I i then Term.absdummy HOLogic.unitT (nth ys i)
   1.153 +            else mk_undefined (HOLogic.unitT --> nth passiveAs i))
   1.154 +          (0 upto (m - 1));
   1.155 +
   1.156 +        fun mk_nat_wit Ds bnf (I, wit) () =
   1.157 +          let
   1.158 +            val passiveI = filter (fn i => i < m) I;
   1.159 +            val map_args = mk_map_args passiveI;
   1.160 +          in
   1.161 +            Term.absdummy HOLogic.unitT (Term.list_comb
   1.162 +              (mk_map_of_bnf Ds all_unitTs (passiveAs @ unitTs) bnf, map_args @ unit_funs) $ wit)
   1.163 +          end;
   1.164 +
   1.165 +        fun mk_dummy_wit Ds bnf I =
   1.166 +          let
   1.167 +            val map_args = mk_map_args I;
   1.168 +          in
   1.169 +            Term.absdummy HOLogic.unitT (Term.list_comb
   1.170 +              (mk_map_of_bnf Ds all_unitTs (passiveAs @ unitTs) bnf, map_args @ unit_funs) $
   1.171 +              mk_undefined (mk_T_of_bnf Ds all_unitTs bnf))
   1.172 +          end;
   1.173 +
   1.174 +        val nat_witss =
   1.175 +          map3 (fn i => fn Ds => fn bnf => mk_wits_of_bnf (replicate (nwits_of_bnf bnf) Ds)
   1.176 +            (replicate (nwits_of_bnf bnf) (replicate live HOLogic.unitT)) bnf
   1.177 +            |> map (fn (I, wit) =>
   1.178 +              (I, Lazy.lazy (mk_nat_wit Ds bnf (I, Term.list_comb (wit, map (K HOLogic.unit) I))))))
   1.179 +          ks Dss bnfs;
   1.180 +
   1.181 +        val nat_wit_thmss = map2 (curry op ~~) nat_witss (map wit_thmss_of_bnf bnfs)
   1.182 +
   1.183 +        val Iss = map (map fst) nat_witss;
   1.184 +
   1.185 +        fun filter_wits (I, wit) =
   1.186 +          let val J = filter (fn i => i < m) I;
   1.187 +          in (J, (length J < length I, wit)) end;
   1.188 +
   1.189 +        val wit_treess = map_index (fn (i, Is) =>
   1.190 +          map_index (finish Iss m [i+m] (i+m)) Is) Iss
   1.191 +          |> map (minimize_wits o map filter_wits o minimize_wits o flat);
   1.192 +
   1.193 +        val coind_wit_argsss =
   1.194 +          map (map (tree_to_coind_wits nat_wit_thmss o snd o snd) o filter (fst o snd)) wit_treess;
   1.195 +
   1.196 +        val nonredundant_coind_wit_argsss =
   1.197 +          fold (fn i => fn argsss =>
   1.198 +            nth_map (i - 1) (filter_out (fn xs =>
   1.199 +              exists (fn ys =>
   1.200 +                let
   1.201 +                  val xs' = (map (fst o fst) xs, snd (fst (hd xs)));
   1.202 +                  val ys' = (map (fst o fst) ys, snd (fst (hd ys)));
   1.203 +                in
   1.204 +                  eq_pair (subset (op =)) (eq_set (op =)) (xs', ys') andalso not (fst xs' = fst ys')
   1.205 +                end)
   1.206 +              (flat argsss)))
   1.207 +            argsss)
   1.208 +          ks coind_wit_argsss;
   1.209 +
   1.210 +        fun prepare_args args =
   1.211 +          let
   1.212 +            val I = snd (fst (hd args));
   1.213 +            val (dummys, args') =
   1.214 +              map_split (fn i =>
   1.215 +                (case find_first (fn arg => fst (fst arg) = i - 1) args of
   1.216 +                  SOME (_, ((_, wit), thms)) => (NONE, (Lazy.force wit, thms))
   1.217 +                | NONE =>
   1.218 +                  (SOME (i - 1), (mk_dummy_wit (nth Dss (i - 1)) (nth bnfs (i - 1)) I, []))))
   1.219 +              ks;
   1.220 +          in
   1.221 +            ((I, dummys), apsnd flat (split_list args'))
   1.222 +          end;
   1.223 +
   1.224 +        fun mk_coind_wits ((I, dummys), (args, thms)) =
   1.225 +          ((I, dummys), (map (fn i => mk_coiter Ts args i $ HOLogic.unit) ks, thms));
   1.226 +
   1.227 +        val coind_witss =
   1.228 +          maps (map (mk_coind_wits o prepare_args)) nonredundant_coind_wit_argsss;
   1.229 +
   1.230 +        val _ = (warning o PolyML.makestring) (map length coind_wit_argsss)
   1.231 +        val _ = (warning o PolyML.makestring) (map length nonredundant_coind_wit_argsss)
   1.232 +
   1.233 +        fun mk_coind_wit_thms ((I, dummys), (wits, wit_thms)) =
   1.234 +          let
   1.235 +            fun mk_goal sets y y_copy y'_copy j =
   1.236 +              let
   1.237 +                fun mk_conjunct set z dummy wit =
   1.238 +                  mk_Ball (set $ z) (Term.absfree y'_copy
   1.239 +                    (if dummy = NONE orelse member (op =) I (j - 1) then
   1.240 +                      HOLogic.mk_imp (HOLogic.mk_eq (z, wit),
   1.241 +                        if member (op =) I (j - 1) then HOLogic.mk_eq (y_copy, y)
   1.242 +                        else @{term False})
   1.243 +                    else @{term True}));
   1.244 +              in
   1.245 +                fold_rev Logic.all (map (nth ys) I @ Jzs) (HOLogic.mk_Trueprop
   1.246 +                  (Library.foldr1 HOLogic.mk_conj (map4 mk_conjunct sets Jzs dummys wits)))
   1.247 +              end;
   1.248 +            val goals = map5 mk_goal setss_by_range ys ys_copy ys'_copy ls;
   1.249 +          in
   1.250 +            map2 (fn goal => fn induct =>
   1.251 +              Skip_Proof.prove lthy [] [] goal
   1.252 +               (mk_coind_wit_tac induct coiter_thms (flat set_natural'ss) wit_thms))
   1.253 +            goals hset_induct_thms
   1.254 +            |> map split_conj_thm
   1.255 +            |> transpose
   1.256 +            |> map (map_filter (try (fn thm => thm RS bspec RS mp)))
   1.257 +            |> curry op ~~ (map_index Library.I (map (close_wit I) wits))
   1.258 +            |> filter (fn (_, thms) => length thms = m)
   1.259 +          end;
   1.260 +
   1.261 +        val coind_wit_thms = maps mk_coind_wit_thms coind_witss;
   1.262 +
   1.263 +        val witss = map2 (fn Ds => fn bnf => mk_wits_of_bnf
   1.264 +          (replicate (nwits_of_bnf bnf) Ds)
   1.265 +          (replicate (nwits_of_bnf bnf) (passiveAs @ Ts)) bnf) Dss bnfs;
   1.266 +
   1.267 +        val fld_witss =
   1.268 +          map (map (uncurry close_wit o tree_to_fld_wit ys flds witss o snd o snd) o
   1.269 +            filter_out (fst o snd)) wit_treess;
   1.270 +
   1.271 +        val all_witss =
   1.272 +          fold (fn ((i, wit), thms) => fn witss =>
   1.273 +            nth_map i (fn (thms', wits) => (thms @ thms', wit :: wits)) witss)
   1.274 +          coind_wit_thms (map (pair []) fld_witss)
   1.275 +          |> map (apsnd (map snd o minimize_wits));
   1.276 +
   1.277 +        val wit_tac = mk_wit_tac n unf_fld_thms (flat set_simp_thmss) (maps wit_thms_of_bnf bnfs);
   1.278 +
   1.279 +        val (Jbnfs, lthy) =
   1.280 +          fold_map6 (fn tacs => fn b => fn map => fn sets => fn T => fn (thms, wits) =>
   1.281 +            bnf_def Dont_Inline user_policy I tacs (wit_tac thms) (SOME deads)
   1.282 +              ((((b, fold_rev Term.absfree fs' map), sets), absdummy T bd), wits))
   1.283 +          tacss bs fs_maps setss_by_bnf Ts all_witss lthy;
   1.284 +
   1.285 +        val fold_maps = Local_Defs.fold lthy (map (fn bnf =>
   1.286 +          mk_unabs_def m (map_def_of_bnf bnf RS @{thm meta_eq_to_obj_eq})) Jbnfs);
   1.287 +
   1.288 +        val fold_sets = Local_Defs.fold lthy (maps (fn bnf =>
   1.289 +         map (fn thm => thm RS @{thm meta_eq_to_obj_eq}) (set_defs_of_bnf bnf)) Jbnfs);
   1.290 +
   1.291 +        val timer = time (timer "registered new codatatypes as BNFs");
   1.292 +
   1.293 +        val set_incl_thmss = map (map fold_sets) hset_unf_incl_thmss;
   1.294 +        val set_set_incl_thmsss = map (map (map fold_sets)) hset_hset_unf_incl_thmsss;
   1.295 +        val set_induct_thms = map fold_sets hset_induct_thms;
   1.296 +
   1.297          val rels = map2 (fn Ds => mk_rel_of_bnf Ds (passiveAs @ Ts) (passiveBs @ Ts')) Dss bnfs;
   1.298          val Jrels = map (mk_rel_of_bnf deads passiveAs passiveBs) Jbnfs;
   1.299          val preds = map2 (fn Ds => mk_pred_of_bnf Ds (passiveAs @ Ts) (passiveBs @ Ts')) Dss bnfs;