src/HOL/Tools/datatype_rep_proofs.ML
changeset 7015 85be09eb136c
parent 6522 2f6cec5c046f
child 7205 dab2be236bfc
--- a/src/HOL/Tools/datatype_rep_proofs.ML	Fri Jul 16 12:09:48 1999 +0200
+++ b/src/HOL/Tools/datatype_rep_proofs.ML	Fri Jul 16 12:14:04 1999 +0200
@@ -12,13 +12,16 @@
 
 *)
 
+val foo = ref [TrueI];
+
 signature DATATYPE_REP_PROOFS =
 sig
   val representation_proofs : bool -> DatatypeAux.datatype_info Symtab.table ->
     string list -> (int * (string * DatatypeAux.dtyp list *
       (string * DatatypeAux.dtyp list) list)) list list -> (string * sort) list ->
         (string * mixfix) list -> (string * mixfix) list list -> theory ->
-          theory * thm list list * thm list list * thm
+          theory * thm list list * thm list list * thm list list *
+            DatatypeAux.simproc_dist list * thm
 end;
 
 structure DatatypeRepProofs : DATATYPE_REP_PROOFS =
@@ -43,15 +46,22 @@
 fun representation_proofs flat_names (dt_info : datatype_info Symtab.table)
       new_type_names descr sorts types_syntax constr_syntax thy =
   let
-    val Univ_thy = the (get_thy "Univ" thy);
-    val node_name = Sign.intern_tycon (Theory.sign_of Univ_thy) "node";
-    val [In0_name, In1_name, Scons_name, Leaf_name, Numb_name] =
-      map (Sign.intern_const (Theory.sign_of Univ_thy))
-        ["In0", "In1", "Scons", "Leaf", "Numb"];
+    val Datatype_thy = theory "Datatype";
+    val node_name = Sign.intern_tycon (Theory.sign_of Datatype_thy) "node";
+    val [In0_name, In1_name, Scons_name, Leaf_name, Numb_name, Lim_name,
+      Funs_name, o_name] =
+      map (Sign.intern_const (Theory.sign_of Datatype_thy))
+        ["In0", "In1", "Scons", "Leaf", "Numb", "Lim", "Funs", "op o"];
+
     val [In0_inject, In1_inject, Scons_inject, Leaf_inject, In0_eq, In1_eq,
-      In0_not_In1, In1_not_In0] = map (get_thm Univ_thy)
-        ["In0_inject", "In1_inject", "Scons_inject", "Leaf_inject", "In0_eq",
-         "In1_eq", "In0_not_In1", "In1_not_In0"];
+         In0_not_In1, In1_not_In0, Funs_mono, FunsI, Lim_inject,
+         Funs_inv, FunsD, Funs_rangeE, Funs_nonempty] = map (get_thm Datatype_thy)
+        ["In0_inject", "In1_inject", "Scons_inject", "Leaf_inject", "In0_eq", "In1_eq",
+         "In0_not_In1", "In1_not_In0", "Funs_mono", "FunsI", "Lim_inject",
+         "Funs_inv", "FunsD", "Funs_rangeE", "Funs_nonempty"];
+
+    val Funs_IntE = (Int_lower2 RS Funs_mono RS
+      (Int_lower1 RS Funs_mono RS Int_greatest) RS subsetD) RS IntE;
 
     val descr' = flat descr;
 
@@ -65,19 +75,23 @@
 
     val tyvars = map (fn (_, (_, Ts, _)) => map dest_DtTFree Ts) (hd descr);
     val leafTs' = get_nonrec_types descr' sorts;
-    val unneeded_vars = hd tyvars \\ foldr add_typ_tfree_names (leafTs', []);
+    val branchTs = get_branching_types descr' sorts;
+    val branchT = if null branchTs then HOLogic.unitT
+      else fold_bal (fn (T, U) => Type ("+", [T, U])) branchTs;
+    val unneeded_vars = hd tyvars \\ foldr add_typ_tfree_names (leafTs' @ branchTs, []);
     val leafTs = leafTs' @ (map (fn n => TFree (n, the (assoc (sorts, n)))) unneeded_vars);
     val recTs = get_rec_types descr' sorts;
     val newTs = take (length (hd descr), recTs);
     val oldTs = drop (length (hd descr), recTs);
     val sumT = if null leafTs then HOLogic.unitT
       else fold_bal (fn (T, U) => Type ("+", [T, U])) leafTs;
-    val Univ_elT = HOLogic.mk_setT (Type (node_name, [sumT]));
+    val Univ_elT = HOLogic.mk_setT (Type (node_name, [sumT, branchT]));
     val UnivT = HOLogic.mk_setT Univ_elT;
 
     val In0 = Const (In0_name, Univ_elT --> Univ_elT);
     val In1 = Const (In1_name, Univ_elT --> Univ_elT);
     val Leaf = Const (Leaf_name, sumT --> Univ_elT);
+    val Lim = Const (Lim_name, (branchT --> Univ_elT) --> Univ_elT);
 
     (* make injections needed for embedding types in leaves *)
 
@@ -103,6 +117,25 @@
       else
         foldr1 (HOLogic.mk_binop Scons_name) ts);
 
+    (* function spaces *)
+
+    fun mk_fun_inj T' x =
+      let
+        fun mk_inj T n i =
+          if n = 1 then x else
+          let
+            val n2 = n div 2;
+            val Type (_, [T1, T2]) = T;
+            val sum_case = Const ("sum_case", [T1 --> Univ_elT, T2 --> Univ_elT, T] ---> Univ_elT)
+          in
+            if i <= n2 then
+              sum_case $ (mk_inj T1 n2 i) $ Const ("arbitrary", T2 --> Univ_elT)
+            else
+              sum_case $ Const ("arbitrary", T1 --> Univ_elT) $ mk_inj T2 (n - n2) (i - n2)
+          end
+      in mk_inj branchT (length branchTs) (1 + find_index_eq T' branchTs)
+      end;
+
     (************** generate introduction rules for representing set **********)
 
     val _ = message "Constructing representing sets ...";
@@ -116,6 +149,14 @@
               in (j + 1, (HOLogic.mk_mem (free_t,
                 Const (nth_elem (k, rep_set_names), UnivT)))::prems, free_t::ts)
               end
+          | mk_prem (DtType ("fun", [T, DtRec k]), (j, prems, ts)) =
+              let val T' = typ_of_dtyp descr' sorts T;
+                  val free_t = mk_Free "x" (T' --> Univ_elT) j
+              in (j + 1, (HOLogic.mk_mem (free_t,
+                Const (Funs_name, UnivT --> HOLogic.mk_setT (T' --> Univ_elT)) $
+                  Const (nth_elem (k, rep_set_names), UnivT)))::prems,
+                    Lim $ mk_fun_inj T' free_t::ts)
+              end
           | mk_prem (dt, (j, prems, ts)) =
               let val T = typ_of_dtyp descr' sorts dt
               in (j + 1, prems, (Leaf $ mk_inj T (mk_Free "x" T j))::ts)
@@ -136,16 +177,17 @@
     val (thy2, {raw_induct = rep_induct, intrs = rep_intrs, ...}) =
       setmp InductivePackage.quiet_mode (!quiet_mode)
         (InductivePackage.add_inductive_i false true big_rec_name false true false
-           consts [] (map (fn x => (("", x), [])) intr_ts) [] []) thy1;
+           consts [] (map (fn x => (("", x), [])) intr_ts) [Funs_mono] []) thy1;
 
     (********************************* typedef ********************************)
 
     val thy3 = add_path flat_names big_name (foldl (fn (thy, ((((name, mx), tvs), c), name')) =>
       setmp TypedefPackage.quiet_mode true
         (TypedefPackage.add_typedef_i_no_def name' (name, tvs, mx) c [] []
-          (Some (QUIET_BREADTH_FIRST (has_fewer_prems 1) (resolve_tac rep_intrs 1)))) thy)
-            (parent_path flat_names thy2, types_syntax ~~ tyvars ~~ (take (length newTs, consts)) ~~
-              new_type_names));
+          (Some (QUIET_BREADTH_FIRST (has_fewer_prems 1)
+            (resolve_tac (Funs_nonempty::rep_intrs) 1)))) thy)
+              (parent_path flat_names thy2, types_syntax ~~ tyvars ~~
+                (take (length newTs, consts)) ~~ new_type_names));
 
     (*********************** definition of constructors ***********************)
 
@@ -171,6 +213,13 @@
           in (case dt of
               DtRec m => (j + 1, free_t::l_args, (Const (nth_elem (m, all_rep_names),
                 T --> Univ_elT) $ free_t)::r_args)
+            | DtType ("fun", [T', DtRec m]) =>
+                let val ([T''], T''') = strip_type T
+                in (j + 1, free_t::l_args, (Lim $ mk_fun_inj T''
+                  (Const (o_name, [T''' --> Univ_elT, T, T''] ---> Univ_elT) $
+                    Const (nth_elem (m, all_rep_names), T''' --> Univ_elT) $ free_t))::r_args)
+                end
+
             | _ => (j + 1, free_t::l_args, (Leaf $ mk_inj T free_t)::r_args))
           end;
 
@@ -200,8 +249,8 @@
         val sg = Theory.sign_of thy;
         val rep_const = cterm_of sg
           (Const (Sign.intern_const sg ("Rep_" ^ tname), T --> Univ_elT));
-        val cong' = cterm_instantiate [(cterm_of sg cong_f, rep_const)] arg_cong;
-        val dist = cterm_instantiate [(cterm_of sg distinct_f, rep_const)] distinct_lemma;
+        val cong' = standard (cterm_instantiate [(cterm_of sg cong_f, rep_const)] arg_cong);
+        val dist = standard (cterm_instantiate [(cterm_of sg distinct_f, rep_const)] distinct_lemma);
         val (thy', defs', eqns', _) = foldl ((make_constr_def tname T) (length constrs))
           ((add_path flat_names tname thy, defs, [], 1), constrs ~~ constr_syntax)
       in
@@ -282,23 +331,34 @@
         val rep_const = Const (rep_name, T --> Univ_elT);
         val constr = Const (cname, argTs ---> T);
 
-        fun process_arg ks' ((i2, i2', ts), dt) =
+        fun process_arg ks' ((i2, i2', ts, Ts), dt) =
           let val T' = typ_of_dtyp descr' sorts dt
           in (case dt of
               DtRec j => if j mem ks' then
-                  (i2 + 1, i2' + 1, ts @ [mk_Free "y" Univ_elT i2'])
+                  (i2 + 1, i2' + 1, ts @ [mk_Free "y" Univ_elT i2'], Ts @ [Univ_elT])
                 else
                   (i2 + 1, i2', ts @ [Const (nth_elem (j, all_rep_names),
-                    T' --> Univ_elT) $ mk_Free "x" T' i2])
-            | _ => (i2 + 1, i2', ts @ [Leaf $ mk_inj T' (mk_Free "x" T' i2)]))
+                    T' --> Univ_elT) $ mk_Free "x" T' i2], Ts)
+            | (DtType ("fun", [_, DtRec j])) =>
+                let val ([T''], T''') = strip_type T'
+                in if j mem ks' then
+                    (i2 + 1, i2' + 1, ts @ [Lim $ mk_fun_inj T''
+                      (mk_Free "y" (T'' --> Univ_elT) i2')], Ts @ [T'' --> Univ_elT])
+                  else
+                    (i2 + 1, i2', ts @ [Lim $ mk_fun_inj T''
+                      (Const (o_name, [T''' --> Univ_elT, T', T''] ---> Univ_elT) $
+                        Const (nth_elem (j, all_rep_names), T''' --> Univ_elT) $
+                          mk_Free "x" T' i2)], Ts)
+                end
+            | _ => (i2 + 1, i2', ts @ [Leaf $ mk_inj T' (mk_Free "x" T' i2)], Ts))
           end;
 
-        val (i2, i2', ts) = foldl (process_arg ks) ((1, 1, []), cargs);
+        val (i2, i2', ts, Ts) = foldl (process_arg ks) ((1, 1, [], []), cargs);
         val xs = map (uncurry (mk_Free "x")) (argTs ~~ (1 upto (i2 - 1)));
-        val ys = map (mk_Free "y" Univ_elT) (1 upto (i2' - 1));
+        val ys = map (uncurry (mk_Free "y")) (Ts ~~ (1 upto (i2' - 1)));
         val f = list_abs_free (map dest_Free (xs @ ys), mk_univ_inj ts n i);
 
-        val (_, _, ts') = foldl (process_arg []) ((1, 1, []), cargs);
+        val (_, _, ts', _) = foldl (process_arg []) ((1, 1, [], []), cargs);
         val eqn = HOLogic.mk_Trueprop (HOLogic.mk_eq
           (rep_const $ list_comb (constr, xs), mk_univ_inj ts' n i))
 
@@ -340,6 +400,21 @@
 
     (* prove isomorphism properties *)
 
+    fun mk_funs_inv thm =
+      let
+        val [_, t] = prems_of Funs_inv;
+        val [_ $ (_ $ _ $ R)] = Logic.strip_assums_hyp t;
+        val _ $ (_ $ (r $ (a $ _)) $ _) = Logic.strip_assums_concl t;
+        val [_ $ (_ $ _ $ R')] = prems_of thm;
+        val _ $ (_ $ (r' $ (a' $ _)) $ _) = concl_of thm;
+        val inv' = cterm_instantiate (map 
+          ((pairself (cterm_of (sign_of_thm thm))) o
+           (apsnd (map_term_types (incr_tvar 1))))
+             [(R, R'), (r, r'), (a, a')]) Funs_inv
+      in
+        rule_by_tactic (atac 2) (thm RSN (2, inv'))
+      end;
+
     (* prove  x : dt_rep_set_i --> x : range dt_Rep_i *)
 
     fun mk_iso_t (((set_name, iso_name), i), T) =
@@ -355,8 +430,6 @@
     val iso_t = HOLogic.mk_Trueprop (mk_conj (map mk_iso_t
       (rep_set_names ~~ all_rep_names ~~ (0 upto (length descr' - 1)) ~~ recTs)));
 
-    val newT_Abs_inverse_thms = map (fn (iso, _, _) => iso RS subst) newT_iso_axms;
-
     (* all the theorems are proved by one single simultaneous induction *)
 
     val iso_thms = if length descr = 1 then [] else
@@ -365,14 +438,19 @@
            [indtac rep_induct 1,
             REPEAT (rtac TrueI 1),
             REPEAT (EVERY
-              [REPEAT (etac rangeE 1),
-               REPEAT (eresolve_tac newT_Abs_inverse_thms 1),
+              [rewrite_goals_tac [mk_meta_eq Collect_mem_eq],
+               REPEAT (etac Funs_IntE 1),
+               REPEAT (eresolve_tac [rangeE, Funs_rangeE] 1),
+               REPEAT (eresolve_tac (map (fn (iso, _, _) => iso RS subst) newT_iso_axms @
+                 map (fn (iso, _, _) => mk_funs_inv iso RS subst) newT_iso_axms) 1),
                TRY (hyp_subst_tac 1),
                rtac (sym RS range_eqI) 1,
                resolve_tac iso_char_thms 1])])));
 
-    val Abs_inverse_thms = newT_Abs_inverse_thms @ (map (fn r =>
-      r RS mp RS f_inv_f RS subst) iso_thms);
+    val Abs_inverse_thms' = (map #1 newT_iso_axms) @ map (fn r => r RS mp RS f_inv_f) iso_thms;
+
+    val Abs_inverse_thms = map (fn r => r RS subst) (Abs_inverse_thms' @
+      map mk_funs_inv Abs_inverse_thms');
 
     (* prove  inj dt_Rep_i  and  dt_Rep_i x : dt_rep_set_i *)
 
@@ -395,7 +473,7 @@
         val (ind_concl1, ind_concl2) = ListPair.unzip (map mk_ind_concl ds);
 
         val rewrites = map mk_meta_eq iso_char_thms;
-        val inj_thms' = map (fn r => r RS injD) inj_thms;
+        val inj_thms' = flat (map (fn r => [r RS injD, r RS inj_o]) inj_thms);
 
         val inj_thm = prove_goalw_cterm [] (cterm_of (Theory.sign_of thy5)
           (HOLogic.mk_Trueprop (mk_conj ind_concl1))) (fn _ =>
@@ -411,8 +489,9 @@
                    ORELSE (EVERY
                      [REPEAT (etac Scons_inject 1),
                       REPEAT (dresolve_tac
-                        (inj_thms' @ [Leaf_inject, Inl_inject, Inr_inject]) 1),
-                      REPEAT (EVERY [etac allE 1, dtac mp 1, atac 1]),
+                        (inj_thms' @ [Leaf_inject, Lim_inject, Inl_inject, Inr_inject]) 1),
+                      REPEAT ((EVERY [etac allE 1, dtac mp 1, atac 1]) ORELSE
+                              (dtac inj_fun_lemma 1 THEN atac 1)),
                       TRY (hyp_subst_tac 1),
                       rtac refl 1])])])]);
 
@@ -425,11 +504,11 @@
 	       (HOLogic.mk_Trueprop (mk_conj ind_concl2)))
 	      (fn _ =>
 	       [indtac induction 1,
-		rewrite_goals_tac rewrites,
+		rewrite_goals_tac (o_def :: rewrites),
 		REPEAT (EVERY
 			[resolve_tac rep_intrs 1,
-			 REPEAT ((atac 1) ORELSE
-				 (resolve_tac elem_thms 1))])]);
+			 REPEAT (FIRST [atac 1, etac spec 1,
+				 resolve_tac (FunsI :: elem_thms) 1])])]);
 
       in (inj_thms @ inj_thms'', elem_thms @ (split_conj_thm elem_thm))
       end;
@@ -446,19 +525,18 @@
     fun prove_constr_rep_thm eqn =
       let
         val inj_thms = map (fn (r, _) => r RS inj_onD) newT_iso_inj_thms;
-        val rewrites = constr_defs @ (map (mk_meta_eq o #2) newT_iso_axms)
+        val rewrites = o_def :: constr_defs @ (map (mk_meta_eq o #2) newT_iso_axms)
       in prove_goalw_cterm [] (cterm_of (Theory.sign_of thy5) eqn) (fn _ =>
         [resolve_tac inj_thms 1,
          rewrite_goals_tac rewrites,
          rtac refl 1,
          resolve_tac rep_intrs 2,
-         REPEAT (resolve_tac iso_elem_thms 1)])
+         REPEAT (resolve_tac (FunsI :: iso_elem_thms) 1)])
       end;
 
     (*--------------------------------------------------------------*)
     (* constr_rep_thms and rep_congs are used to prove distinctness *)
-    (* of constructors internally.                                  *)
-    (* the external version uses dt_case which is not defined yet   *)
+    (* of constructors.                                             *)
     (*--------------------------------------------------------------*)
 
     val constr_rep_thms = map (map prove_constr_rep_thm) constr_rep_eqns;
@@ -467,27 +545,45 @@
       dist_lemma::(rep_thms @ [In0_eq, In1_eq, In0_not_In1, In1_not_In0]))
         (constr_rep_thms ~~ dist_lemmas);
 
+    fun prove_distinct_thms (_, []) = []
+      | prove_distinct_thms (dist_rewrites', t::_::ts) =
+          let
+            val dist_thm = prove_goalw_cterm [] (cterm_of (Theory.sign_of thy5) t) (fn _ =>
+              [simp_tac (HOL_ss addsimps dist_rewrites') 1])
+          in dist_thm::(standard (dist_thm RS not_sym))::
+            (prove_distinct_thms (dist_rewrites', ts))
+          end;
+
+    val distinct_thms = map prove_distinct_thms (dist_rewrites ~~
+      DatatypeProp.make_distincts new_type_names descr sorts thy5);
+
+    val simproc_dists = map (fn ((((_, (_, _, constrs)), rep_thms), congr), dists) =>
+      if length constrs < !DatatypeProp.dtK then FewConstrs dists
+      else ManyConstrs (congr, HOL_basic_ss addsimps rep_thms)) (hd descr ~~
+        constr_rep_thms ~~ rep_congs ~~ distinct_thms);
+
     (* prove injectivity of constructors *)
 
     fun prove_constr_inj_thm rep_thms t =
-      let val inj_thms = Scons_inject::(map make_elim
+      let val inj_thms = Scons_inject::sum_case_inject::(map make_elim
         ((map (fn r => r RS injD) iso_inj_thms) @
-          [In0_inject, In1_inject, Leaf_inject, Inl_inject, Inr_inject]))
+          [In0_inject, In1_inject, Leaf_inject, Inl_inject, Inr_inject, Lim_inject]))
       in prove_goalw_cterm [] (cterm_of (Theory.sign_of thy5) t) (fn _ =>
         [rtac iffI 1,
          REPEAT (etac conjE 2), hyp_subst_tac 2, rtac refl 2,
          dresolve_tac rep_congs 1, dtac box_equals 1,
-         REPEAT (resolve_tac rep_thms 1),
+         REPEAT (resolve_tac rep_thms 1), rewrite_goals_tac [o_def],
          REPEAT (eresolve_tac inj_thms 1),
-         hyp_subst_tac 1,
-         REPEAT (resolve_tac [conjI, refl] 1)])
+         REPEAT (ares_tac [conjI] 1 ORELSE (EVERY [rtac ext 1, dtac fun_cong 1,
+                  eresolve_tac inj_thms 1, atac 1]))])
       end;
 
     val constr_inject = map (fn (ts, thms) => map (prove_constr_inj_thm thms) ts)
       ((DatatypeProp.make_injs descr sorts) ~~ constr_rep_thms);
 
-    val thy6 = store_thmss "inject" new_type_names
-      constr_inject (parent_path flat_names thy5);
+    val thy6 = thy5 |> parent_path flat_names |>
+      store_thmss "inject" new_type_names constr_inject |>
+      store_thmss "distinct" new_type_names distinct_thms;
 
     (*************************** induction theorem ****************************)
 
@@ -538,17 +634,18 @@
       (DatatypeProp.make_ind descr sorts)) (fn prems =>
         [rtac indrule_lemma' 1, indtac rep_induct 1,
          EVERY (map (fn (prem, r) => (EVERY
-           [REPEAT (eresolve_tac Abs_inverse_thms 1),
+           [REPEAT (eresolve_tac (Funs_IntE::Abs_inverse_thms) 1),
             simp_tac (HOL_basic_ss addsimps ((symmetric r)::Rep_inverse_thms')) 1,
-            DEPTH_SOLVE_1 (ares_tac [prem] 1)]))
-              (prems ~~ (constr_defs @ (map mk_meta_eq iso_char_thms))))]);
+            DEPTH_SOLVE_1 (ares_tac [prem] 1 ORELSE (EVERY [rewrite_goals_tac [o_def],
+              rtac allI 1, dtac FunsD 1, etac CollectD 1]))]))
+                (prems ~~ (constr_defs @ (map mk_meta_eq iso_char_thms))))]);
 
     val thy7 = thy6 |>
       Theory.add_path big_name |>
       PureThy.add_thms [(("induct", dt_induct), [])] |>
       Theory.parent_path;
 
-  in (thy7, constr_inject, dist_rewrites, dt_induct)
+  in (thy7, constr_inject, distinct_thms, dist_rewrites, simproc_dists, dt_induct)
   end;
 
 end;