added countable tactic for new-style datatypes
authorblanchet
Wed, 03 Sep 2014 00:31:38 +0200
changeset 58160 e4965b677ba9
parent 58159 e3d1912a0c8f
child 58161 deeff89d5b9e
added countable tactic for new-style datatypes
src/HOL/Library/Countable.thy
src/HOL/Library/bnf_lfp_countable.ML
--- a/src/HOL/Library/Countable.thy	Wed Sep 03 00:31:37 2014 +0200
+++ b/src/HOL/Library/Countable.thy	Wed Sep 03 00:31:38 2014 +0200
@@ -1,6 +1,7 @@
 (*  Title:      HOL/Library/Countable.thy
     Author:     Alexander Krauss, TU Muenchen
     Author:     Brian Huffman, Portland State University
+    Author:     Jasmin Blanchette, TU Muenchen
 *)
 
 header {* Encoding (almost) everything into natural numbers *}
@@ -49,10 +50,7 @@
   by (simp add: from_nat_def)
 
 
-subsection {* Countable types *}
-
-instance nat :: countable
-  by (rule countable_classI [of "id"]) simp
+subsection {* Finite types are countable *}
 
 subclass (in finite) countable
 proof
@@ -65,113 +63,8 @@
   then show "\<exists>to_nat \<Colon> 'a \<Rightarrow> nat. inj to_nat" by (rule exI[of inj])
 qed
 
-text {* Pairs *}
 
-instance prod :: (countable, countable) countable
-  by (rule countable_classI [of "\<lambda>(x, y). prod_encode (to_nat x, to_nat y)"])
-    (auto simp add: prod_encode_eq)
-
-
-text {* Sums *}
-
-instance sum :: (countable, countable) countable
-  by (rule countable_classI [of "(\<lambda>x. case x of Inl a \<Rightarrow> to_nat (False, to_nat a)
-                                     | Inr b \<Rightarrow> to_nat (True, to_nat b))"])
-    (simp split: sum.split_asm)
-
-
-text {* Integers *}
-
-instance int :: countable
-  by (rule countable_classI [of "int_encode"])
-    (simp add: int_encode_eq)
-
-
-text {* Options *}
-
-instance option :: (countable) countable
-  by (rule countable_classI [of "case_option 0 (Suc \<circ> to_nat)"])
-    (simp split: option.split_asm)
-
-
-text {* Lists *}
-
-instance list :: (countable) countable
-  by (rule countable_classI [of "list_encode \<circ> map to_nat"])
-    (simp add: list_encode_eq)
-
-
-text {* Further *}
-
-instance String.literal :: countable
-  by (rule countable_classI [of "to_nat o String.explode"])
-    (auto simp add: explode_inject)
-
-text {* Functions *}
-
-instance "fun" :: (finite, countable) countable
-proof
-  obtain xs :: "'a list" where xs: "set xs = UNIV"
-    using finite_list [OF finite_UNIV] ..
-  show "\<exists>to_nat::('a \<Rightarrow> 'b) \<Rightarrow> nat. inj to_nat"
-  proof
-    show "inj (\<lambda>f. to_nat (map f xs))"
-      by (rule injI, simp add: xs fun_eq_iff)
-  qed
-qed
-
-
-subsection {* The Rationals are Countably Infinite *}
-
-definition nat_to_rat_surj :: "nat \<Rightarrow> rat" where
-"nat_to_rat_surj n = (let (a,b) = prod_decode n
-                      in Fract (int_decode a) (int_decode b))"
-
-lemma surj_nat_to_rat_surj: "surj nat_to_rat_surj"
-unfolding surj_def
-proof
-  fix r::rat
-  show "\<exists>n. r = nat_to_rat_surj n"
-  proof (cases r)
-    fix i j assume [simp]: "r = Fract i j" and "j > 0"
-    have "r = (let m = int_encode i; n = int_encode j
-               in nat_to_rat_surj(prod_encode (m,n)))"
-      by (simp add: Let_def nat_to_rat_surj_def)
-    thus "\<exists>n. r = nat_to_rat_surj n" by(auto simp:Let_def)
-  qed
-qed
-
-lemma Rats_eq_range_nat_to_rat_surj: "\<rat> = range nat_to_rat_surj"
-by (simp add: Rats_def surj_nat_to_rat_surj)
-
-context field_char_0
-begin
-
-lemma Rats_eq_range_of_rat_o_nat_to_rat_surj:
-  "\<rat> = range (of_rat o nat_to_rat_surj)"
-using surj_nat_to_rat_surj
-by (auto simp: Rats_def image_def surj_def)
-   (blast intro: arg_cong[where f = of_rat])
-
-lemma surj_of_rat_nat_to_rat_surj:
-  "r\<in>\<rat> \<Longrightarrow> \<exists>n. r = of_rat(nat_to_rat_surj n)"
-by(simp add: Rats_eq_range_of_rat_o_nat_to_rat_surj image_def)
-
-end
-
-instance rat :: countable
-proof
-  show "\<exists>to_nat::rat \<Rightarrow> nat. inj to_nat"
-  proof
-    have "surj nat_to_rat_surj"
-      by (rule surj_nat_to_rat_surj)
-    then show "inj (inv nat_to_rat_surj)"
-      by (rule surj_imp_inj_inv)
-  qed
-qed
-
-
-subsection {* Automatically proving countability of datatypes *}
+subsection {* Automatically proving countability of old-style datatypes *}
 
 inductive finite_item :: "'a Old_Datatype.item \<Rightarrow> bool" where
   undefined: "finite_item undefined"
@@ -268,8 +161,8 @@
 qed
 
 ML {*
-  fun countable_tac ctxt =
-    SUBGOAL (fn (goal, i) =>
+  fun old_countable_tac ctxt =
+    SUBGOAL (fn (goal, _) =>
       let
         val ty_name =
           (case goal of
@@ -279,7 +172,7 @@
         val typedef_thm = #type_definition (snd typedef_info)
         val pred_name =
           (case HOLogic.dest_Trueprop (concl_of typedef_thm) of
-            (typedef $ rep $ abs $ (collect $ Const (n, _))) => n
+            (_ $ _ $ _ $ (_ $ Const (n, _))) => n
           | _ => raise Match)
         val induct_info = Inductive.the_inductive ctxt pred_name
         val pred_names = #names (fst induct_info)
@@ -301,33 +194,124 @@
       end)
 *}
 
-method_setup countable_datatype = {*
-  Scan.succeed (fn ctxt => SIMPLE_METHOD' (countable_tac ctxt))
-*} "prove countable class instances for datatypes"
-
 hide_const (open) finite_item nth_item
 
 
-subsection {* Countable datatypes *}
+subsection {* Automatically proving countability of new-style datatypes *}
+
+ML_file "bnf_lfp_countable.ML"
+
+method_setup countable_datatype = {*
+  Scan.succeed (fn ctxt =>
+    SIMPLE_METHOD (fn st => HEADGOAL (old_countable_tac ctxt) st
+      handle ERROR _ => BNF_LFP_Countable.countable_tac ctxt st))
+*} "prove countable class instances for datatypes"
+
+
+subsection {* More Countable types *}
+
+text {* Naturals *}
 
-(* TODO: automate *)
+instance nat :: countable
+  by (rule countable_classI [of "id"]) simp
+
+text {* Pairs *}
+
+instance prod :: (countable, countable) countable
+  by (rule countable_classI [of "\<lambda>(x, y). prod_encode (to_nat x, to_nat y)"])
+    (auto simp add: prod_encode_eq)
 
-primrec encode_typerep :: "typerep \<Rightarrow> nat" where
-  "encode_typerep (Typerep.Typerep s ts) = prod_encode (to_nat s, to_nat (map encode_typerep ts))"
+text {* Sums *}
+
+instance sum :: (countable, countable) countable
+  by (rule countable_classI [of "(\<lambda>x. case x of Inl a \<Rightarrow> to_nat (False, to_nat a)
+                                     | Inr b \<Rightarrow> to_nat (True, to_nat b))"])
+    (simp split: sum.split_asm)
+
+text {* Integers *}
 
-lemma encode_typerep_injective: "\<forall>u. encode_typerep t = encode_typerep u \<longrightarrow> t = u"
-  apply (induct t)
-  apply (rule allI)
-  apply (case_tac u)
-  apply (auto simp: sum_encode_eq prod_encode_eq elim: list.inj_map_strong[rotated 1])
-  done
+instance int :: countable
+  by (rule countable_classI [of int_encode]) (simp add: int_encode_eq)
+
+text {* Options *}
+
+instance option :: (countable) countable
+  by countable_datatype
+
+text {* Lists *}
+
+instance list :: (countable) countable
+  by countable_datatype
+
+text {* String literals *}
+
+instance String.literal :: countable
+  by (rule countable_classI [of "to_nat o String.explode"]) (auto simp add: explode_inject)
+
+text {* Functions *}
+
+instance "fun" :: (finite, countable) countable
+proof
+  obtain xs :: "'a list" where xs: "set xs = UNIV"
+    using finite_list [OF finite_UNIV] ..
+  show "\<exists>to_nat::('a \<Rightarrow> 'b) \<Rightarrow> nat. inj to_nat"
+  proof
+    show "inj (\<lambda>f. to_nat (map f xs))"
+      by (rule injI, simp add: xs fun_eq_iff)
+  qed
+qed
+
+text {* Typereps *}
 
 instance typerep :: countable
-  apply default
-  apply (unfold inj_on_def ball_UNIV)
-  apply (rule exI)
-  apply (rule allI)
-  apply (rule encode_typerep_injective)
-  done
+  by countable_datatype
+
+
+subsection {* The rationals are countably infinite *}
+
+definition nat_to_rat_surj :: "nat \<Rightarrow> rat" where
+  "nat_to_rat_surj n = (let (a, b) = prod_decode n in Fract (int_decode a) (int_decode b))"
+
+lemma surj_nat_to_rat_surj: "surj nat_to_rat_surj"
+unfolding surj_def
+proof
+  fix r::rat
+  show "\<exists>n. r = nat_to_rat_surj n"
+  proof (cases r)
+    fix i j assume [simp]: "r = Fract i j" and "j > 0"
+    have "r = (let m = int_encode i; n = int_encode j
+               in nat_to_rat_surj(prod_encode (m,n)))"
+      by (simp add: Let_def nat_to_rat_surj_def)
+    thus "\<exists>n. r = nat_to_rat_surj n" by(auto simp:Let_def)
+  qed
+qed
+
+lemma Rats_eq_range_nat_to_rat_surj: "\<rat> = range nat_to_rat_surj"
+  by (simp add: Rats_def surj_nat_to_rat_surj)
+
+context field_char_0
+begin
+
+lemma Rats_eq_range_of_rat_o_nat_to_rat_surj:
+  "\<rat> = range (of_rat o nat_to_rat_surj)"
+  using surj_nat_to_rat_surj
+  by (auto simp: Rats_def image_def surj_def) (blast intro: arg_cong[where f = of_rat])
+
+lemma surj_of_rat_nat_to_rat_surj:
+  "r\<in>\<rat> \<Longrightarrow> \<exists>n. r = of_rat(nat_to_rat_surj n)"
+  by (simp add: Rats_eq_range_of_rat_o_nat_to_rat_surj image_def)
 
 end
+
+instance rat :: countable
+proof
+  show "\<exists>to_nat::rat \<Rightarrow> nat. inj to_nat"
+  proof
+    have "surj nat_to_rat_surj"
+      by (rule surj_nat_to_rat_surj)
+    then show "inj (inv nat_to_rat_surj)"
+      by (rule surj_imp_inj_inv)
+  qed
+qed
+
+end
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/HOL/Library/bnf_lfp_countable.ML	Wed Sep 03 00:31:38 2014 +0200
@@ -0,0 +1,183 @@
+(*  Title:      HOL/Library/bnf_lfp_countable.ML
+    Author:     Jasmin Blanchette, TU Muenchen
+    Copyright   2014
+
+Countability tactic for BNF datatypes.
+*)
+
+signature BNF_LFP_COUNTABLE =
+sig
+  val countable_tac: Proof.context -> tactic
+end;
+
+structure BNF_LFP_Countable : BNF_LFP_COUNTABLE =
+struct
+
+open BNF_FP_Rec_Sugar_Util
+open BNF_Def
+open BNF_Util
+open BNF_Tactics
+open BNF_FP_Util
+open BNF_FP_Def_Sugar
+
+fun nchotomy_tac nchotomy =
+  HEADGOAL (rtac (nchotomy RS @{thm all_reg[rotated]}) THEN'
+    CHANGED_PROP o REPEAT_ALL_NEW (match_tac [allI, impI]) THEN'
+    CHANGED_PROP o REPEAT_ALL_NEW (ematch_tac [exE, disjE]));
+
+fun meta_spec_mp_tac 0 = K all_tac
+  | meta_spec_mp_tac n =
+    dtac meta_spec THEN' meta_spec_mp_tac (n - 1) THEN' dtac meta_mp THEN' atac;
+
+val use_induction_hypothesis_tac =
+  DEEPEN (1, 1000000 (* large number *))
+    (fn depth => meta_spec_mp_tac depth THEN' etac allE THEN' etac impE THEN' atac THEN' atac) 0;
+
+val same_ctr_simps =
+  @{thms sum_encode_eq prod_encode_eq sum.inject prod.inject to_nat_split simp_thms snd_conv};
+
+fun same_ctr_tac ctxt injects recs map_comps' inj_map_strongs' =
+  HEADGOAL (asm_full_simp_tac (ss_only (injects @ recs @ map_comps' @ same_ctr_simps) ctxt) THEN'
+    REPEAT_DETERM o CHANGED_PROP o REPEAT_ALL_NEW (ematch_tac (conjE :: inj_map_strongs'))
+    THEN_ALL_NEW use_induction_hypothesis_tac);
+
+fun distinct_ctrs_tac ctxt recs =
+  HEADGOAL (asm_full_simp_tac (ss_only (recs @
+    @{thms sum_encode_eq sum.inject sum.distinct simp_thms}) ctxt));
+
+fun mk_encode_injective_tac ctxt n nchotomy injects recs map_comps' inj_map_strongs' =
+  let val ks = 1 upto n in
+    EVERY (maps (fn k => nchotomy_tac nchotomy :: map (fn k' =>
+      if k = k' then same_ctr_tac ctxt injects recs map_comps' inj_map_strongs'
+      else distinct_ctrs_tac ctxt recs) ks) ks)
+  end;
+
+fun mk_encode_injectives_tac ctxt ns induct nchotomys injectss recss map_comps' inj_map_strongs' =
+  HEADGOAL (rtac induct) THEN
+  EVERY (map4 (fn n => fn nchotomy => fn injects => fn recs =>
+      mk_encode_injective_tac ctxt n nchotomy injects recs map_comps' inj_map_strongs')
+    ns nchotomys injectss recss);
+
+fun endgame_tac ctxt encode_injectives =
+  unfold_thms_tac ctxt @{thms inj_on_def ball_UNIV} THEN
+  ALLGOALS (rtac exI THEN' rtac allI THEN' resolve_tac encode_injectives);
+
+fun encode_sumN n k t =
+  Balanced_Tree.access {init = t,
+      left = fn t => @{const sum_encode} $ (@{const Inl (nat, nat)} $ t),
+      right = fn t => @{const sum_encode} $ (@{const Inr (nat, nat)} $ t)}
+    n k;
+
+fun encode_tuple [] = @{term "0 :: nat"}
+  | encode_tuple ts =
+    Balanced_Tree.make (fn (t, u) => @{const prod_encode} $ (@{const Pair (nat, nat)} $ u $ t)) ts;
+
+fun mk_to_nat T = Const (@{const_name to_nat}, T --> HOLogic.natT);
+
+fun mk_encode_funs ctxt fpTs ns ctrss0 recs0 =
+  let
+    val thy = Proof_Context.theory_of ctxt;
+
+    val nn = length ns;
+    val recs as rec1 :: _ =
+      map2 (fn fpT => mk_co_rec thy Least_FP fpT (replicate nn HOLogic.natT)) fpTs recs0;
+    val arg_Ts = binder_fun_types (fastype_of rec1);
+    val arg_Tss = Library.unflat ctrss0 arg_Ts;
+
+    fun mk_U (Type (@{type_name prod}, [T1, T2])) =
+        if member (op =) fpTs T1 then T2 else HOLogic.mk_prodT (mk_U T1, mk_U T2)
+      | mk_U (Type (s, Ts)) = Type (s, map mk_U Ts)
+      | mk_U T = T;
+
+    fun mk_nat (j, T) =
+      if T = HOLogic.natT then
+        SOME (Bound j)
+      else if member (op =) fpTs T then
+        NONE
+      else if exists_subtype_in fpTs T then
+        let val U = mk_U T in
+          SOME (mk_to_nat U $ (build_map ctxt [] (snd_const o fst) (T, U) $ Bound j))
+        end
+      else
+        SOME (mk_to_nat T $ Bound j);
+
+    fun mk_arg n (k, arg_T) =
+      let
+        val bound_Ts = rev (binder_types arg_T);
+        val nats = map_filter mk_nat (tag_list 0 bound_Ts);
+      in
+        fold (fn T => fn t => Abs (Name.uu, T, t)) bound_Ts (encode_sumN n k (encode_tuple nats))
+      end;
+
+    val argss = map2 (map o mk_arg) ns (map (tag_list 1) arg_Tss);
+  in
+    map (fn recx => Term.list_comb (recx, flat argss)) recs
+  end;
+
+fun mk_encode_injective_thms _ [] = []
+  | mk_encode_injective_thms ctxt fpT_names0 =
+    let
+      fun not_datatype s = error (quote s ^ " is not a new-style datatype");
+      fun not_mutually_recursive ss =
+        error ("{" ^ commas ss ^ "} is not a set of mutually recursive new-style datatypes");
+
+      fun lfp_sugar_of s =
+        (case fp_sugar_of ctxt s of
+          SOME (fp_sugar as {fp = Least_FP, ...}) => fp_sugar
+        | _ => not_datatype s);
+
+      val fpTs0 as Type (_, var_As) :: _ = #Ts (#fp_res (lfp_sugar_of (hd fpT_names0)));
+      val fpT_names = map (fst o dest_Type) fpTs0;
+
+      val (As_names, _) = Variable.variant_fixes (map (fn TVar ((s, _), _) => s) var_As) ctxt;
+      val As =
+        map2 (fn s => fn TVar (_, S) => TFree (s, union (op =) @{sort countable} S))
+          As_names var_As;
+      val fpTs = map (fn s => Type (s, As)) fpT_names;
+
+      val _ = subset (op =) (fpT_names0, fpT_names) orelse not_mutually_recursive fpT_names0;
+
+      fun mk_conjunct fpT x encode_fun =
+        HOLogic.all_const fpT $ Abs (Name.uu, fpT,
+          HOLogic.mk_imp (HOLogic.mk_eq (encode_fun $ x, encode_fun $ Bound 0),
+            HOLogic.eq_const fpT $ x $ Bound 0));
+
+      val fp_sugars as {fp_nesting_bnfs, common_co_inducts = induct :: _, ...} :: _ =
+        map (the o fp_sugar_of ctxt o fst o dest_Type) fpTs0;
+      val ctr_sugars = map #ctr_sugar fp_sugars;
+
+      val ctrss0 = map #ctrs ctr_sugars;
+      val ns = map length ctrss0;
+      val recs0 = map #co_rec fp_sugars;
+      val nchotomys = map #nchotomy ctr_sugars;
+      val injectss = map #injects ctr_sugars;
+      val rec_thmss = map #co_rec_thms fp_sugars;
+      val map_comps' = map (unfold_thms ctxt @{thms comp_def} o map_comp_of_bnf) fp_nesting_bnfs;
+      val inj_map_strongs' = map (Thm.permute_prems 0 ~1 o inj_map_strong_of_bnf) fp_nesting_bnfs;
+
+      val (xs, names_ctxt) = ctxt |> mk_Frees "x" fpTs;
+
+      val conjuncts = map3 mk_conjunct fpTs xs (mk_encode_funs ctxt fpTs ns ctrss0 recs0);
+      val goal = HOLogic.mk_Trueprop (Balanced_Tree.make HOLogic.mk_conj conjuncts);
+    in
+      Goal.prove_sorry ctxt [] [] goal (fn {context = ctxt, prems = _} =>
+        mk_encode_injectives_tac ctxt ns induct nchotomys injectss rec_thmss map_comps'
+        inj_map_strongs')
+      |> HOLogic.conj_elims
+      |> Proof_Context.export names_ctxt ctxt
+      |> map Thm.close_derivation
+    end;
+
+fun get_countable_goal_typ (@{const Trueprop} $ (Const (@{const_name Ex}, _)
+    $ Abs (_, Type (_, [Type (s, _), _]), Const (@{const_name inj_on}, _) $ Bound 0
+        $ Const (@{const_name top}, _)))) = s
+  | get_countable_goal_typ _ = error "Wrong goal format for countable tactic";
+
+fun core_countable_tac ctxt st =
+  endgame_tac ctxt (mk_encode_injective_thms ctxt (map get_countable_goal_typ (Thm.prems_of st)))
+    st;
+
+fun countable_tac ctxt =
+  TRY (Class.intro_classes_tac []) THEN core_countable_tac ctxt;
+
+end;