src/HOL/Library/bnf_lfp_countable.ML
author blanchet
Wed Sep 24 15:45:55 2014 +0200 (2014-09-24)
changeset 58425 246985c6b20b
parent 58315 6d8458bc6e27
child 58459 f70bffabd7cf
permissions -rw-r--r--
simpler proof
     1 (*  Title:      HOL/Library/bnf_lfp_countable.ML
     2     Author:     Jasmin Blanchette, TU Muenchen
     3     Copyright   2014
     4 
     5 Countability tactic for BNF datatypes.
     6 *)
     7 
     8 signature BNF_LFP_COUNTABLE =
     9 sig
    10   val derive_encode_injectives_thms: Proof.context -> string list -> thm list
    11   val countable_datatype_tac: Proof.context -> tactic
    12 end;
    13 
    14 structure BNF_LFP_Countable : BNF_LFP_COUNTABLE =
    15 struct
    16 
    17 open BNF_FP_Rec_Sugar_Util
    18 open BNF_Def
    19 open BNF_Util
    20 open BNF_Tactics
    21 open BNF_FP_Util
    22 open BNF_FP_Def_Sugar
    23 
    24 val countableS = @{sort countable};
    25 
    26 fun nchotomy_tac nchotomy =
    27   HEADGOAL (rtac (nchotomy RS @{thm all_reg[rotated]}) THEN'
    28     REPEAT_ALL_NEW (resolve_tac [allI, impI] ORELSE' eresolve_tac [exE, disjE]));
    29 
    30 fun meta_spec_mp_tac 0 = K all_tac
    31   | meta_spec_mp_tac depth =
    32     dtac meta_spec THEN' meta_spec_mp_tac (depth - 1) THEN' dtac meta_mp THEN' atac;
    33 
    34 val use_induction_hypothesis_tac =
    35   DEEPEN (1, 64 (* large number *))
    36     (fn depth => meta_spec_mp_tac depth THEN' etac allE THEN' etac impE THEN' atac THEN' atac) 0;
    37 
    38 val same_ctr_simps = @{thms sum_encode_eq prod_encode_eq sum.inject prod.inject to_nat_split
    39   id_apply snd_conv simp_thms};
    40 val distinct_ctrs_simps = @{thms sum_encode_eq sum.inject sum.distinct simp_thms};
    41 
    42 fun same_ctr_tac ctxt injects recs map_congs' inj_map_strongs' =
    43   HEADGOAL (asm_full_simp_tac
    44       (ss_only (injects @ recs @ map_congs' @ same_ctr_simps) ctxt) THEN_MAYBE'
    45     TRY o REPEAT_ALL_NEW (rtac conjI) THEN_ALL_NEW
    46     REPEAT_ALL_NEW (eresolve_tac (conjE :: inj_map_strongs')) THEN_ALL_NEW
    47     (atac ORELSE' use_induction_hypothesis_tac));
    48 
    49 fun distinct_ctrs_tac ctxt recs =
    50   HEADGOAL (asm_full_simp_tac (ss_only (recs @ distinct_ctrs_simps) ctxt));
    51 
    52 fun mk_encode_injective_tac ctxt n nchotomy injects recs map_comps' inj_map_strongs' =
    53   let val ks = 1 upto n in
    54     EVERY (maps (fn k => nchotomy_tac nchotomy :: map (fn k' =>
    55       if k = k' then same_ctr_tac ctxt injects recs map_comps' inj_map_strongs'
    56       else distinct_ctrs_tac ctxt recs) ks) ks)
    57   end;
    58 
    59 fun mk_encode_injectives_tac ctxt ns induct nchotomys injectss recss map_comps' inj_map_strongs' =
    60   HEADGOAL (rtac induct) THEN
    61   EVERY (map4 (fn n => fn nchotomy => fn injects => fn recs =>
    62       mk_encode_injective_tac ctxt n nchotomy injects recs map_comps' inj_map_strongs')
    63     ns nchotomys injectss recss);
    64 
    65 fun endgame_tac ctxt encode_injectives =
    66   unfold_thms_tac ctxt @{thms inj_on_def ball_UNIV} THEN
    67   ALLGOALS (rtac exI THEN' rtac allI THEN' resolve_tac encode_injectives);
    68 
    69 fun encode_sumN n k t =
    70   Balanced_Tree.access {init = t,
    71       left = fn t => @{const sum_encode} $ (@{const Inl (nat, nat)} $ t),
    72       right = fn t => @{const sum_encode} $ (@{const Inr (nat, nat)} $ t)}
    73     n k;
    74 
    75 fun encode_tuple [] = @{term "0 :: nat"}
    76   | encode_tuple ts =
    77     Balanced_Tree.make (fn (t, u) => @{const prod_encode} $ (@{const Pair (nat, nat)} $ u $ t)) ts;
    78 
    79 fun mk_encode_funs ctxt fpTs ns ctrss0 recs0 =
    80   let
    81     val thy = Proof_Context.theory_of ctxt;
    82 
    83     fun check_countable T =
    84       Sign.of_sort thy (T, countableS) orelse
    85       raise TYPE ("Type is not of sort " ^ Syntax.string_of_sort ctxt countableS, [T], []);
    86 
    87     fun mk_to_nat_checked T =
    88       Const (@{const_name to_nat}, tap check_countable T --> HOLogic.natT);
    89 
    90     val nn = length ns;
    91     val recs as rec1 :: _ = map2 (mk_co_rec thy Least_FP (replicate nn HOLogic.natT)) fpTs recs0;
    92     val arg_Ts = binder_fun_types (fastype_of rec1);
    93     val arg_Tss = Library.unflat ctrss0 arg_Ts;
    94 
    95     fun mk_U (Type (@{type_name prod}, [T1, T2])) =
    96         if member (op =) fpTs T1 then T2 else HOLogic.mk_prodT (mk_U T1, mk_U T2)
    97       | mk_U (Type (s, Ts)) = Type (s, map mk_U Ts)
    98       | mk_U T = T;
    99 
   100     fun mk_nat (j, T) =
   101       if T = HOLogic.natT then
   102         SOME (Bound j)
   103       else if member (op =) fpTs T then
   104         NONE
   105       else if exists_subtype_in fpTs T then
   106         let val U = mk_U T in
   107           SOME (mk_to_nat_checked U $ (build_map ctxt [] (snd_const o fst) (T, U) $ Bound j))
   108         end
   109       else
   110         SOME (mk_to_nat_checked T $ Bound j);
   111 
   112     fun mk_arg n (k, arg_T) =
   113       let
   114         val bound_Ts = rev (binder_types arg_T);
   115         val nats = map_filter mk_nat (tag_list 0 bound_Ts);
   116       in
   117         fold (fn T => fn t => Abs (Name.uu, T, t)) bound_Ts (encode_sumN n k (encode_tuple nats))
   118       end;
   119 
   120     val argss = map2 (map o mk_arg) ns (map (tag_list 1) arg_Tss);
   121   in
   122     map (fn recx => Term.list_comb (recx, flat argss)) recs
   123   end;
   124 
   125 fun derive_encode_injectives_thms _ [] = []
   126   | derive_encode_injectives_thms ctxt fpT_names0 =
   127     let
   128       fun not_datatype s = error (quote s ^ " is not a datatype");
   129       fun not_mutually_recursive ss = error (commas ss ^ " are not mutually recursive datatypes");
   130 
   131       fun lfp_sugar_of s =
   132         (case fp_sugar_of ctxt s of
   133           SOME (fp_sugar as {fp = Least_FP, ...}) => fp_sugar
   134         | _ => not_datatype s);
   135 
   136       val fpTs0 as Type (_, var_As) :: _ =
   137         map (#T o lfp_sugar_of o fst o dest_Type) (#Ts (#fp_res (lfp_sugar_of (hd fpT_names0))));
   138       val fpT_names = map (fst o dest_Type) fpTs0;
   139 
   140       val (As_names, _) = Variable.variant_fixes (map (fn TVar ((s, _), _) => s) var_As) ctxt;
   141       val As =
   142         map2 (fn s => fn TVar (_, S) => TFree (s, union (op =) countableS S))
   143           As_names var_As;
   144       val fpTs = map (fn s => Type (s, As)) fpT_names;
   145 
   146       val _ = subset (op =) (fpT_names0, fpT_names) orelse not_mutually_recursive fpT_names0;
   147 
   148       fun mk_conjunct fpT x encode_fun =
   149         HOLogic.all_const fpT $ Abs (Name.uu, fpT,
   150           HOLogic.mk_imp (HOLogic.mk_eq (encode_fun $ x, encode_fun $ Bound 0),
   151             HOLogic.eq_const fpT $ x $ Bound 0));
   152 
   153       val fp_sugars as {fp_nesting_bnfs, common_co_inducts = induct :: _, ...} :: _ =
   154         map (the o fp_sugar_of ctxt o fst o dest_Type) fpTs0;
   155       val ctr_sugars = map #ctr_sugar fp_sugars;
   156 
   157       val ctrss0 = map #ctrs ctr_sugars;
   158       val ns = map length ctrss0;
   159       val recs0 = map #co_rec fp_sugars;
   160       val nchotomys = map #nchotomy ctr_sugars;
   161       val injectss = map #injects ctr_sugars;
   162       val rec_thmss = map #co_rec_thms fp_sugars;
   163       val map_comps' = map (unfold_thms ctxt @{thms comp_def} o map_comp_of_bnf) fp_nesting_bnfs;
   164       val inj_map_strongs' = map (Thm.permute_prems 0 ~1 o inj_map_strong_of_bnf) fp_nesting_bnfs;
   165 
   166       val (xs, names_ctxt) = ctxt |> mk_Frees "x" fpTs;
   167 
   168       val conjuncts = map3 mk_conjunct fpTs xs (mk_encode_funs ctxt fpTs ns ctrss0 recs0);
   169       val goal = HOLogic.mk_Trueprop (Library.foldr1 HOLogic.mk_conj conjuncts);
   170     in
   171       Goal.prove (*no sorry*) ctxt [] [] goal (fn {context = ctxt, prems = _} =>
   172         mk_encode_injectives_tac ctxt ns induct nchotomys injectss rec_thmss map_comps'
   173           inj_map_strongs')
   174       |> HOLogic.conj_elims
   175       |> Proof_Context.export names_ctxt ctxt
   176       |> map Thm.close_derivation
   177     end;
   178 
   179 fun get_countable_goal_type_name (@{const Trueprop} $ (Const (@{const_name Ex}, _)
   180     $ Abs (_, Type (_, [Type (s, _), _]), Const (@{const_name inj_on}, _) $ Bound 0
   181         $ Const (@{const_name top}, _)))) = s
   182   | get_countable_goal_type_name _ = error "Wrong goal format for datatype countability tactic";
   183 
   184 fun core_countable_datatype_tac ctxt st =
   185   let val T_names = map get_countable_goal_type_name (Thm.prems_of st) in
   186     endgame_tac ctxt (derive_encode_injectives_thms ctxt T_names) st
   187   end;
   188 
   189 fun countable_datatype_tac ctxt =
   190   TRY (Class.intro_classes_tac []) THEN core_countable_datatype_tac ctxt;
   191 
   192 end;