src/HOL/Tools/datatype_abs_proofs.ML
changeset 5177 0d3a168e4d44
child 5303 22029546d109
equal deleted inserted replaced
5176:36d38be7e814 5177:0d3a168e4d44
       
     1 (*  Title:      HOL/Tools/datatype_abs_proofs.ML
       
     2     ID:         $Id$
       
     3     Author:     Stefan Berghofer
       
     4     Copyright   1998  TU Muenchen
       
     5 
       
     6 Proofs and definitions independent of concrete representation
       
     7 of datatypes  (i.e. requiring only abstract properties such as
       
     8 injectivity / distinctness of constructors and induction)
       
     9 
       
    10  - case distinction (exhaustion) theorems
       
    11  - characteristic equations for primrec combinators
       
    12  - characteristic equations for case combinators
       
    13  - distinctness of constructors (external version)
       
    14  - equations for splitting "P (case ...)" expressions
       
    15  - datatype size function
       
    16  - "nchotomy" and "case_cong" theorems for TFL
       
    17 
       
    18 *)
       
    19 
       
    20 signature DATATYPE_ABS_PROOFS =
       
    21 sig
       
         (* Proof procedures run when datatypes are introduced. Each takes the
            new type names, the datatype description descr (a list of lists;
            its head lists the newly introduced types) and the sort
            constraints. Each returns the extended theory together with the
            theorems it proved; those theorems are also stored in that theory. *)
       
         (* Case distinction theorems, derived from the induction rule;
            stored under "exhaust". *)
    22   val prove_casedist_thms : string list -> (int * (string * DatatypeAux.dtyp list *
       
    23     (string * DatatypeAux.dtyp list) list)) list list -> (string * sort) list ->
       
    24       thm -> theory -> theory * thm list
       
         (* Primitive recursion combinators: defines them and proves their
            characteristic equations (stored under "recs"). Also needs the
            info of previously defined datatypes, the constructor injectivity
            and distinctness rewrites, and the induction rule. Additionally
            returns the combinator names. *)
    25   val prove_primrec_thms : string list -> (int * (string * DatatypeAux.dtyp list *
       
    26     (string * DatatypeAux.dtyp list) list)) list list -> (string * sort) list ->
       
    27       DatatypeAux.datatype_info Symtab.table -> thm list list -> thm list list ->
       
    28         thm -> theory -> theory * string list * thm list
       
         (* Case combinators ("_case"): defines them via the given primrec
            combinator names and proves their equations from the primrec
            equations (stored under "cases"). Additionally returns the case
            combinator names. *)
    29   val prove_case_thms : string list -> (int * (string * DatatypeAux.dtyp list *
       
    30     (string * DatatypeAux.dtyp list) list)) list list -> (string * sort) list ->
       
    31       string list -> thm list -> theory -> theory * string list * thm list list
       
         (* Distinctness of constructors (stored under "distinct"). Types with
            at least DatatypeProp.dtK constructors use an auxiliary "_ord"
            function defined from the case combinator. *)
    32   val prove_distinctness_thms : string list -> (int * (string * DatatypeAux.dtyp list *
       
    33     (string * DatatypeAux.dtyp list) list)) list list -> (string * sort) list ->
       
    34       thm list list -> thm list list -> theory -> theory * thm list list
       
         (* Splitting rules for case expressions, stored under "split" and
            "split_asm"; returned as pairs (split, split_asm). *)
    35   val prove_split_thms : string list -> (int * (string * DatatypeAux.dtyp list *
       
    36     (string * DatatypeAux.dtyp list) list)) list list -> (string * sort) list ->
       
    37       thm list list -> thm list list -> thm list -> thm list list -> theory ->
       
    38         theory * (thm * thm) list
       
         (* Size functions, defined via the primrec combinators; equations
            stored under "size". *)
    39   val prove_size_thms : string list -> (int * (string * DatatypeAux.dtyp list *
       
    40     (string * DatatypeAux.dtyp list) list)) list list -> (string * sort) list ->
       
    41       string list -> thm list -> theory -> theory * thm list
       
         (* "nchotomy" theorems for TFL, derived from the exhaust theorems. *)
    42   val prove_nchotomys : string list -> (int * (string * DatatypeAux.dtyp list *
       
    43     (string * DatatypeAux.dtyp list) list)) list list -> (string * sort) list ->
       
    44       thm list -> theory -> theory * thm list
       
         (* "case_cong" congruence rules for TFL, from the nchotomy theorems
            and the case equations. *)
    45   val prove_case_congs : string list -> (int * (string * DatatypeAux.dtyp list *
       
    46     (string * DatatypeAux.dtyp list) list)) list list -> (string * sort) list ->
       
    47       thm list -> thm list list -> theory -> theory * thm list
       
    48 end;
       
    49 
       
    50 structure DatatypeAbsProofs : DATATYPE_ABS_PROOFS =
       
    51 struct
       
    52 
       
    53 open DatatypeAux;
       
    54 
       
    55 val thin = read_instantiate_sg (sign_of Set.thy) [("V", "?X : ?Y")] thin_rl;
       (* thin: thin_rl with its variable V instantiated to the membership
          pattern "?X : ?Y". It is used as "etac thin" in the primrec
          uniqueness proof (prove_primrec_thms), presumably to discard
          membership assumptions. *)
       
    56 
       
       (* distinct_f: the function variable found by pattern matching on the
          first premise of distinct_lemma. prove_distinctness_thms instantiates
          it with the function extracted, by the same pattern, from the first
          premise of the goal being proved. NOTE(review): the premise shape is
          only implied by this pattern; confirm against distinct_lemma. *)
    57 val (_ $ (_ $ (_ $ (distinct_f $ _) $ _))) = hd (prems_of distinct_lemma);
       
    58 
       
    59 (************************ case distinction theorems ***************************)
       
    60 
       
    61 fun prove_casedist_thms new_type_names descr sorts induct thy =
       
    62   let
       
    63     val _ = writeln "Proving case distinction theorems...";
       
    64 
       
    65     val descr' = flat descr;
       
    66     val recTs = get_rec_types descr' sorts;
       
    67     val newTs = take (length (hd descr), recTs);
       
    68 
       
    69     val induct_Ps = map head_of (dest_conj (HOLogic.dest_Trueprop (concl_of induct)));
       
    70 
       
    71     fun prove_casedist_thm ((i, t), T) =
       
    72       let
       
    73         val dummyPs = map (fn (Var (_, Type (_, [T', T'']))) =>
       
    74           Abs ("z", T', Const ("True", T''))) induct_Ps;
       
    75         val P = Abs ("z", T, HOLogic.imp $ HOLogic.mk_eq (Var (("a", 0), T), Bound 0) $
       
    76           Var (("P", 0), HOLogic.boolT))
       
    77         val insts = take (i, dummyPs) @ (P::(drop (i + 1, dummyPs)));
       
    78         val cert = cterm_of (sign_of thy);
       
    79         val insts' = (map cert induct_Ps) ~~ (map cert insts);
       
    80         val induct' = refl RS ((nth_elem (i,
       
    81           split_conj_thm (cterm_instantiate insts' induct))) RSN (2, rev_mp))
       
    82 
       
    83       in prove_goalw_cterm [] (cert t) (fn prems =>
       
    84         [rtac induct' 1,
       
    85          REPEAT (rtac TrueI 1),
       
    86          REPEAT ((rtac impI 1) THEN (eresolve_tac prems 1)),
       
    87          REPEAT (rtac TrueI 1)])
       
    88       end;
       
    89 
       
    90     val casedist_thms = map prove_casedist_thm ((0 upto (length newTs - 1)) ~~
       
    91       (DatatypeProp.make_casedists descr sorts) ~~ newTs)
       
    92 
       
    93   in
       
    94     (store_thms "exhaust" new_type_names casedist_thms thy, casedist_thms)
       
    95   end;
       
    96 
       
    97 (*************************** primrec combinators ******************************)
       
    98 
       
    99 fun prove_primrec_thms new_type_names descr sorts
       
   100     (dt_info : datatype_info Symtab.table) constr_inject dist_rewrites induct thy =
       
   101   let
       
   102     val _ = writeln "Constructing primrec combinators...";
       
   103 
       
   104     val descr' = flat descr;
       
   105     val recTs = get_rec_types descr' sorts;
       
   106     val newTs = take (length (hd descr), recTs);
       
   107 
       
   108     val induct_Ps = map head_of (dest_conj (HOLogic.dest_Trueprop (concl_of induct)));
       
   109 
       
   110     val big_rec_name' = (space_implode "_" new_type_names) ^ "_rec_set";
       
   111     val rec_set_names = map (Sign.full_name (sign_of thy))
       
   112       (if length descr' = 1 then [big_rec_name'] else
       
   113         (map ((curry (op ^) (big_rec_name' ^ "_")) o string_of_int)
       
   114           (1 upto (length descr'))));
       
   115 
       
   116     val rec_result_Ts = map (fn (i, _) =>
       
   117       TFree ("'t" ^ (string_of_int (i + 1)), HOLogic.termS)) descr';
       
   118 
       
   119     val reccomb_fn_Ts = flat (map (fn (i, (_, _, constrs)) =>
       
   120       map (fn (_, cargs) =>
       
   121         let
       
   122           val recs = filter is_rec_type cargs;
       
   123           val argTs = (map (typ_of_dtyp descr' sorts) cargs) @
       
   124             (map (fn r => nth_elem (dest_DtRec r, rec_result_Ts)) recs)
       
   125         in argTs ---> nth_elem (i, rec_result_Ts)
       
   126         end) constrs) descr');
       
   127 
       
   128     val rec_set_Ts = map (fn (T1, T2) => reccomb_fn_Ts ---> HOLogic.mk_setT
       
   129       (HOLogic.mk_prodT (T1, T2))) (recTs ~~ rec_result_Ts);
       
   130 
       
   131     val rec_fns = map (uncurry (mk_Free "f"))
       
   132       (reccomb_fn_Ts ~~ (1 upto (length reccomb_fn_Ts)));
       
   133     val rec_sets = map (fn c => list_comb (Const c, rec_fns))
       
   134       (rec_set_names ~~ rec_set_Ts);
       
   135 
       
   136     (* introduction rules for graph of primrec function *)
       
   137 
       
   138     fun make_rec_intr T set_name ((rec_intr_ts, l), (cname, cargs)) =
       
   139       let
       
   140         fun mk_prem (dt, (j, k, prems, t1s, t2s)) =
       
   141           let
       
   142             val T = typ_of_dtyp descr' sorts dt;
       
   143             val free1 = mk_Free "x" T j
       
   144           in (case dt of
       
   145              DtRec m =>
       
   146                let val free2 = mk_Free "y" (nth_elem (m, rec_result_Ts)) k
       
   147                in (j + 1, k + 1, (HOLogic.mk_Trueprop (HOLogic.mk_mem
       
   148                  (HOLogic.mk_prod (free1, free2), nth_elem (m, rec_sets))))::prems,
       
   149                    free1::t1s, free2::t2s)
       
   150                end
       
   151            | _ => (j + 1, k, prems, free1::t1s, t2s))
       
   152           end;
       
   153 
       
   154         val Ts = map (typ_of_dtyp descr' sorts) cargs;
       
   155         val (_, _, prems, t1s, t2s) = foldr mk_prem (cargs, (1, 1, [], [], []))
       
   156 
       
   157       in (rec_intr_ts @ [Logic.list_implies (prems, HOLogic.mk_Trueprop (HOLogic.mk_mem
       
   158         (HOLogic.mk_prod (list_comb (Const (cname, Ts ---> T), t1s),
       
   159           list_comb (nth_elem (l, rec_fns), t1s @ t2s)), set_name)))], l + 1)
       
   160       end;
       
   161 
       
   162     val (rec_intr_ts, _) = foldl (fn (x, ((d, T), set_name)) =>
       
   163       foldl (make_rec_intr T set_name) (x, #3 (snd d)))
       
   164         (([], 0), descr' ~~ recTs ~~ rec_sets);
       
   165 
       
   166     val (thy1, {intrs = rec_intrs, elims = rec_elims, ...}) =
       
   167       InductivePackage.add_inductive_i false true big_rec_name' false false true
       
   168         rec_sets rec_intr_ts [] [] thy;
       
   169 
       
   170     (* prove uniqueness and termination of primrec combinators *)
       
   171 
       
   172     val _ = writeln "Proving termination and uniqueness of primrec functions...";
       
   173 
       
   174     fun mk_unique_tac ((tac, intrs), ((((i, (tname, _, constrs)), elim), T), T')) =
       
   175       let
       
   176         val distinct_tac = (etac Pair_inject 1) THEN
       
   177           (if i < length newTs then
       
   178              full_simp_tac (HOL_ss addsimps (nth_elem (i, dist_rewrites))) 1
       
   179            else full_simp_tac (HOL_ss addsimps
       
   180              ((#distinct (the (Symtab.lookup (dt_info, tname)))) @
       
   181                [Suc_Suc_eq, Suc_not_Zero, Zero_not_Suc])) 1);
       
   182 
       
   183         val inject = map (fn r => r RS iffD1)
       
   184           (if i < length newTs then nth_elem (i, constr_inject)
       
   185             else #inject (the (Symtab.lookup (dt_info, tname))));
       
   186 
       
   187         fun mk_unique_constr_tac n ((tac, intr::intrs, j), (cname, cargs)) =
       
   188           let
       
   189             val k = length (filter is_rec_type cargs)
       
   190 
       
   191           in (EVERY [DETERM tac,
       
   192                 REPEAT (etac ex1E 1), rtac ex1I 1,
       
   193                 DEPTH_SOLVE_1 (ares_tac [intr] 1),
       
   194                 REPEAT_DETERM_N k (etac thin 1),
       
   195                 etac elim 1,
       
   196                 REPEAT_DETERM_N j distinct_tac,
       
   197                 etac Pair_inject 1, TRY (dresolve_tac inject 1),
       
   198                 REPEAT (etac conjE 1), hyp_subst_tac 1,
       
   199                 REPEAT (etac allE 1),
       
   200                 REPEAT (dtac mp 1 THEN atac 1),
       
   201                 TRY (hyp_subst_tac 1),
       
   202                 rtac refl 1,
       
   203                 REPEAT_DETERM_N (n - j - 1) distinct_tac],
       
   204               intrs, j + 1)
       
   205           end;
       
   206 
       
   207         val (tac', intrs', _) = foldl (mk_unique_constr_tac (length constrs))
       
   208           ((tac, intrs, 0), constrs);
       
   209 
       
   210       in (tac', intrs') end;
       
   211 
       
   212     val rec_unique_thms =
       
   213       let
       
   214         val rec_unique_ts = map (fn (((set_t, T1), T2), i) =>
       
   215           Const ("Ex1", (T2 --> HOLogic.boolT) --> HOLogic.boolT) $
       
   216             absfree ("y", T2, HOLogic.mk_mem (HOLogic.mk_prod
       
   217               (mk_Free "x" T1 i, Free ("y", T2)), set_t)))
       
   218                 (rec_sets ~~ recTs ~~ rec_result_Ts ~~ (1 upto length recTs));
       
   219         val cert = cterm_of (sign_of thy1)
       
   220         val insts = map (fn ((i, T), t) => absfree ("x" ^ (string_of_int i), T, t))
       
   221           ((1 upto length recTs) ~~ recTs ~~ rec_unique_ts);
       
   222         val induct' = cterm_instantiate ((map cert induct_Ps) ~~
       
   223           (map cert insts)) induct;
       
   224         val (tac, _) = foldl mk_unique_tac
       
   225           ((rtac induct' 1, rec_intrs), descr' ~~ rec_elims ~~ recTs ~~ rec_result_Ts)
       
   226 
       
   227       in split_conj_thm (prove_goalw_cterm []
       
   228         (cert (HOLogic.mk_Trueprop (mk_conj rec_unique_ts))) (K [tac]))
       
   229       end;
       
   230 
       
   231     val rec_total_thms = map (fn r =>
       
   232       r RS ex1_implies_ex RS (select_eq_Ex RS iffD2)) rec_unique_thms;
       
   233 
       
   234     (* define primrec combinators *)
       
   235 
       
   236     val big_reccomb_name = (space_implode "_" new_type_names) ^ "_rec";
       
   237     val reccomb_names = map (Sign.full_name (sign_of thy1))
       
   238       (if length descr' = 1 then [big_reccomb_name] else
       
   239         (map ((curry (op ^) (big_reccomb_name ^ "_")) o string_of_int)
       
   240           (1 upto (length descr'))));
       
   241     val reccombs = map (fn ((name, T), T') => list_comb
       
   242       (Const (name, reccomb_fn_Ts @ [T] ---> T'), rec_fns))
       
   243         (reccomb_names ~~ recTs ~~ rec_result_Ts);
       
   244 
       
   245     val thy2 = thy1 |>
       
   246       Theory.add_consts_i (map (fn ((name, T), T') =>
       
   247         (Sign.base_name name, reccomb_fn_Ts @ [T] ---> T', NoSyn))
       
   248           (reccomb_names ~~ recTs ~~ rec_result_Ts)) |>
       
   249       Theory.add_defs_i (map (fn ((((name, comb), set), T), T') =>
       
   250         ((Sign.base_name name) ^ "_def", Logic.mk_equals
       
   251           (comb $ Free ("x", T),
       
   252            Const ("Eps", (T' --> HOLogic.boolT) --> T') $ absfree ("y", T',
       
   253              HOLogic.mk_mem (HOLogic.mk_prod (Free ("x", T), Free ("y", T')), set)))))
       
   254                (reccomb_names ~~ reccombs ~~ rec_sets ~~ recTs ~~ rec_result_Ts));
       
   255 
       
   256     val reccomb_defs = map ((get_def thy2) o Sign.base_name) reccomb_names;
       
   257 
       
   258     (* prove characteristic equations for primrec combinators *)
       
   259 
       
   260     val _ = writeln "Proving characteristic theorems for primrec combinators..."
       
   261 
       
   262     val rec_thms = map (fn t => prove_goalw_cterm reccomb_defs
       
   263       (cterm_of (sign_of thy2) t) (fn _ =>
       
   264         [rtac select1_equality 1,
       
   265          resolve_tac rec_unique_thms 1,
       
   266          resolve_tac rec_intrs 1,
       
   267          REPEAT (resolve_tac rec_total_thms 1)]))
       
   268            (DatatypeProp.make_primrecs new_type_names descr sorts thy2)
       
   269 
       
   270   in
       
   271     (PureThy.add_tthmss [(("recs", map Attribute.tthm_of rec_thms), [])] thy2,
       
   272      reccomb_names, rec_thms)
       
   273   end;
       
   274 
       
   275 (***************************** case combinators *******************************)
       
   276 
       
   277 fun prove_case_thms new_type_names descr sorts reccomb_names primrec_thms thy =
       
   278   let
       
   279     val _ = writeln "Proving characteristic theorems for case combinators...";
       
   280 
       
   281     val descr' = flat descr;
       
   282     val recTs = get_rec_types descr' sorts;
       
   283     val newTs = take (length (hd descr), recTs);
       
   284 
       
   285     val case_dummy_fns = map (fn (_, (_, _, constrs)) => map (fn (_, cargs) =>
       
   286       let
       
   287         val Ts = map (typ_of_dtyp descr' sorts) cargs;
       
   288         val free = TFree ("'t", HOLogic.termS);
       
   289         val Ts' = replicate (length (filter is_rec_type cargs)) free
       
   290       in Const ("arbitrary", Ts @ Ts' ---> free)
       
   291       end) constrs) descr';
       
   292 
       
   293     val case_names = map (fn s =>
       
   294       Sign.full_name (sign_of thy) (s ^ "_case")) new_type_names;
       
   295 
       
   296     (* define case combinators via primrec combinators *)
       
   297 
       
   298     val (case_defs, thy2) = foldl (fn ((defs, thy),
       
   299       ((((i, (_, _, constrs)), T), name), recname)) =>
       
   300         let
       
   301           val T' = TFree ("'t", HOLogic.termS);
       
   302 
       
   303           val (fns1, fns2) = ListPair.unzip (map (fn ((_, cargs), j) =>
       
   304             let
       
   305               val Ts = map (typ_of_dtyp descr' sorts) cargs;
       
   306               val Ts' = Ts @ (replicate (length (filter is_rec_type cargs)) T');
       
   307               val frees' = map (uncurry (mk_Free "x")) (Ts' ~~ (1 upto length Ts'));
       
   308               val frees = take (length cargs, frees');
       
   309               val free = mk_Free "f" (Ts ---> T') j
       
   310             in
       
   311              (free, list_abs_free (map dest_Free frees',
       
   312                list_comb (free, frees)))
       
   313             end) (constrs ~~ (1 upto length constrs)));
       
   314 
       
   315           val caseT = (map (snd o dest_Free) fns1) @ [T] ---> T';
       
   316           val fns = (flat (take (i, case_dummy_fns))) @
       
   317             fns2 @ (flat (drop (i + 1, case_dummy_fns)));
       
   318           val reccomb = Const (recname, (map fastype_of fns) @ [T] ---> T');
       
   319           val decl = (Sign.base_name name, caseT, NoSyn);
       
   320           val def = ((Sign.base_name name) ^ "_def",
       
   321             Logic.mk_equals (list_comb (Const (name, caseT), fns1),
       
   322               list_comb (reccomb, (flat (take (i, case_dummy_fns))) @
       
   323                 fns2 @ (flat (drop (i + 1, case_dummy_fns))) )));
       
   324           val thy' = thy |>
       
   325             Theory.add_consts_i [decl] |> Theory.add_defs_i [def];
       
   326 
       
   327         in (defs @ [get_def thy' (Sign.base_name name)], thy')
       
   328         end) (([], thy), (hd descr) ~~ newTs ~~ case_names ~~
       
   329           (take (length newTs, reccomb_names)));
       
   330 
       
   331     val case_thms = map (map (fn t => prove_goalw_cterm (case_defs @
       
   332       (map mk_meta_eq primrec_thms)) (cterm_of (sign_of thy2) t)
       
   333         (fn _ => [rtac refl 1])))
       
   334           (DatatypeProp.make_cases new_type_names descr sorts thy2);
       
   335 
       
   336     val thy3 = Theory.add_trrules_i
       
   337       (DatatypeProp.make_case_trrules new_type_names descr) thy2
       
   338 
       
   339   in (store_thmss "cases" new_type_names case_thms thy3, case_names, case_thms)
       
   340   end;
       
   341 
       
   342 (************************ distinctness of constructors ************************)
       
   343 
       
   344 fun prove_distinctness_thms new_type_names descr sorts dist_rewrites case_thms thy =
       
   345   let
       
   346     val descr' = flat descr;
       
   347     val recTs = get_rec_types descr' sorts;
       
   348     val newTs = take (length (hd descr), recTs);
       
   349 
       
   350     (*--------------------------------------------------------------------*)
       
   351     (* define t_ord - functions for proving distinctness of constructors: *)
       
   352     (*  t_ord C_i ... = i                                                 *)
       
   353     (*--------------------------------------------------------------------*)
       
   354 
       
   355     fun define_ord ((thy, ord_defs), (((_, (_, _, constrs)), T), tname)) =
       
   356       if length constrs < DatatypeProp.dtK then (thy, ord_defs)
       
   357       else
       
   358         let
       
   359           val Tss = map ((map (typ_of_dtyp descr' sorts)) o snd) constrs;
       
   360           val ts = map HOLogic.mk_nat (0 upto length constrs - 1);
       
   361           val mk_abs = foldr (fn (T, t') => Abs ("x", T, t'));
       
   362           val fs = map mk_abs (Tss ~~ ts);
       
   363           val fTs = map (fn Ts => Ts ---> HOLogic.natT) Tss;
       
   364           val ord_name = Sign.full_name (sign_of thy) (tname ^ "_ord");
       
   365           val case_name = Sign.intern_const (sign_of thy) (tname ^ "_case");
       
   366           val ordT = T --> HOLogic.natT;
       
   367           val caseT = fTs ---> ordT;
       
   368           val defpair = (tname ^ "_ord_def", Logic.mk_equals
       
   369             (Const (ord_name, ordT), list_comb (Const (case_name, caseT), fs)));
       
   370           val thy' = thy |>
       
   371             Theory.add_consts_i [(tname ^ "_ord", ordT, NoSyn)] |>
       
   372             Theory.add_defs_i [defpair];
       
   373           val def = get_def thy' (tname ^ "_ord")
       
   374 
       
   375         in (thy', ord_defs @ [def]) end;
       
   376 
       
   377     val (thy2, ord_defs) =
       
   378       foldl define_ord ((thy, []), (hd descr) ~~ newTs ~~ new_type_names);
       
   379 
       
   380     (**** number of constructors < dtK ****)
       
   381 
       
   382     fun prove_distinct_thms _ [] = []
       
   383       | prove_distinct_thms dist_rewrites' (t::_::ts) =
       
   384           let
       
   385             val dist_thm = prove_goalw_cterm [] (cterm_of (sign_of thy2) t) (fn _ =>
       
   386               [simp_tac (HOL_ss addsimps dist_rewrites') 1])
       
   387           in dist_thm::(standard (dist_thm RS not_sym))::
       
   388             (prove_distinct_thms dist_rewrites' ts)
       
   389           end;
       
   390 
       
   391     val distinct_thms = map (fn ((((_, (_, _, constrs)), ts),
       
   392       dist_rewrites'), case_thms) =>
       
   393         if length constrs < DatatypeProp.dtK then
       
   394           prove_distinct_thms dist_rewrites' ts
       
   395         else 
       
   396           let
       
   397             val t::ts' = rev ts;
       
   398             val (_ $ (_ $ (_ $ (f $ _) $ _))) = hd (Logic.strip_imp_prems t);
       
   399             val cert = cterm_of (sign_of thy2);
       
   400             val distinct_lemma' = cterm_instantiate
       
   401               [(cert distinct_f, cert f)] distinct_lemma;
       
   402             val rewrites = ord_defs @ (map mk_meta_eq case_thms)
       
   403           in
       
   404             (map (fn t => prove_goalw_cterm rewrites (cert t)
       
   405               (fn _ => [rtac refl 1])) (rev ts')) @ [standard distinct_lemma']
       
   406           end) ((hd descr) ~~ (DatatypeProp.make_distincts new_type_names
       
   407             descr sorts thy2) ~~ dist_rewrites ~~ case_thms)
       
   408 
       
   409   in (store_thmss "distinct" new_type_names distinct_thms thy2, distinct_thms)
       
   410   end;
       
   411 
       
   412 (******************************* case splitting *******************************)
       
   413 
       
   414 fun prove_split_thms new_type_names descr sorts constr_inject dist_rewrites
       
   415     casedist_thms case_thms thy =
       
   416   let
       
   417     val _ = writeln "Proving equations for case splitting...";
       
   418 
       
   419     val descr' = flat descr;
       
   420     val recTs = get_rec_types descr' sorts;
       
   421     val newTs = take (length (hd descr), recTs);
       
   422 
       
   423     fun prove_split_thms ((((((t1, t2), inject), dist_rewrites'),
       
   424         exhaustion), case_thms'), T) =
       
   425       let
       
   426         val cert = cterm_of (sign_of thy);
       
   427         val _ $ (_ $ lhs $ _) = hd (Logic.strip_assums_hyp (hd (prems_of exhaustion)));
       
   428         val exhaustion' = cterm_instantiate
       
   429           [(cert lhs, cert (Free ("x", T)))] exhaustion;
       
   430         val tacsf = K [rtac exhaustion' 1, ALLGOALS (asm_simp_tac
       
   431           (HOL_ss addsimps (dist_rewrites' @ inject @ case_thms')))]
       
   432       in
       
   433         (prove_goalw_cterm [] (cert t1) tacsf,
       
   434          prove_goalw_cterm [] (cert t2) tacsf)
       
   435       end;
       
   436 
       
   437     val split_thm_pairs = map prove_split_thms
       
   438       ((DatatypeProp.make_splits new_type_names descr sorts thy) ~~ constr_inject ~~
       
   439         dist_rewrites ~~ casedist_thms ~~ case_thms ~~ newTs);
       
   440 
       
   441     val (split_thms, split_asm_thms) = ListPair.unzip split_thm_pairs
       
   442 
       
   443   in
       
   444     (thy |> store_thms "split" new_type_names split_thms |>
       
   445             store_thms "split_asm" new_type_names split_asm_thms,
       
   446      split_thm_pairs)
       
   447   end;
       
   448 
       
   449 (******************************* size functions *******************************)
       
   450 
       
   451 fun prove_size_thms new_type_names descr sorts reccomb_names primrec_thms thy =
       
   452   let
       
   453     val _ = writeln "Proving equations for size function...";
       
   454 
       
   455     val descr' = flat descr;
       
   456     val recTs = get_rec_types descr' sorts;
       
   457 
       
   458     val big_size_name = space_implode "_" new_type_names ^ "_size";
       
   459     val size_name = Sign.intern_const (sign_of (the (get_thy "Arith" thy))) "size";
       
   460     val size_names = replicate (length (hd descr)) size_name @
       
   461       map (Sign.full_name (sign_of thy))
       
   462         (if length (flat (tl descr)) = 1 then [big_size_name] else
       
   463           map (fn i => big_size_name ^ "_" ^ string_of_int i)
       
   464             (1 upto length (flat (tl descr))));
       
   465     val def_names = map (fn i => big_size_name ^ "_def_" ^ string_of_int i)
       
   466       (1 upto length recTs);
       
   467 
       
   468     val plus_t = Const ("op +", [HOLogic.natT, HOLogic.natT] ---> HOLogic.natT);
       
   469 
       
   470     fun make_sizefun (_, cargs) =
       
   471       let
       
   472         val Ts = map (typ_of_dtyp descr' sorts) cargs;
       
   473         val k = length (filter is_rec_type cargs);
       
   474         val t = if k = 0 then HOLogic.zero else
       
   475           foldl1 (app plus_t) (map Bound (k - 1 downto 0) @ [HOLogic.mk_nat 1])
       
   476       in
       
   477         foldr (fn (T, t') => Abs ("x", T, t')) (Ts @ replicate k HOLogic.natT, t)
       
   478       end;
       
   479 
       
   480     val fs = flat (map (fn (_, (_, _, constrs)) => map make_sizefun constrs) descr');
       
   481     val fTs = map fastype_of fs;
       
   482 
       
   483     val thy' = thy |>
       
   484       Theory.add_consts_i (map (fn (s, T) =>
       
   485         (Sign.base_name s, T --> HOLogic.natT, NoSyn))
       
   486           (drop (length (hd descr), size_names ~~ recTs))) |>
       
   487       Theory.add_defs_i (map (fn (((s, T), def_name), rec_name) =>
       
   488         (def_name, Logic.mk_equals (Const (s, T --> HOLogic.natT),
       
   489           list_comb (Const (rec_name, fTs @ [T] ---> HOLogic.natT), fs))))
       
   490             (size_names ~~ recTs ~~ def_names ~~ reccomb_names));
       
   491 
       
   492     val size_def_thms = map (get_axiom thy') def_names;
       
   493     val rewrites = size_def_thms @ map mk_meta_eq primrec_thms;
       
   494 
       
   495     val size_thms = map (fn t => prove_goalw_cterm rewrites
       
   496       (cterm_of (sign_of thy') t) (fn _ => [rtac refl 1]))
       
   497         (DatatypeProp.make_size new_type_names descr sorts thy')
       
   498 
       
   499   in
       
   500     (PureThy.add_tthmss [(("size", map Attribute.tthm_of size_thms), [])] thy',
       
   501      size_thms)
       
   502   end;
       
   503 
       
   504 (************************* additional theorems for TFL ************************)
       
   505 
       
   506 fun prove_nchotomys new_type_names descr sorts casedist_thms thy =
       
   507   let
       
   508     val _ = writeln "Proving additional theorems for TFL...";
       
   509 
       
   510     fun prove_nchotomy (t, exhaustion) =
       
   511       let
       
   512         (* For goal i, select the correct disjunct to attack, then prove it *)
       
   513         fun tac i 0 = EVERY [TRY (rtac disjI1 i),
       
   514               hyp_subst_tac i, REPEAT (rtac exI i), rtac refl i]
       
   515           | tac i n = rtac disjI2 i THEN tac i (n - 1)
       
   516       in 
       
   517         prove_goalw_cterm [] (cterm_of (sign_of thy) t) (fn _ =>
       
   518           [rtac allI 1,
       
   519            exh_tac (K exhaustion) 1,
       
   520            ALLGOALS (fn i => tac i (i-1))])
       
   521       end;
       
   522 
       
   523     val nchotomys =
       
   524       map prove_nchotomy (DatatypeProp.make_nchotomys descr sorts ~~ casedist_thms)
       
   525 
       
   526   in
       
   527     (store_thms "nchotomy" new_type_names nchotomys thy, nchotomys)
       
   528   end;
       
   529 
       
   530 fun prove_case_congs new_type_names descr sorts nchotomys case_thms thy =
       
   531   let
       
   532     fun prove_case_cong ((t, nchotomy), case_rewrites) =
       
   533       let
       
   534         val (Const ("==>", _) $ tm $ _) = t;
       
   535         val (Const ("Trueprop", _) $ (Const ("op =", _) $ _ $ Ma)) = tm;
       
   536         val cert = cterm_of (sign_of thy);
       
   537         val nchotomy' = nchotomy RS spec;
       
   538         val nchotomy'' = cterm_instantiate
       
   539           [(cert (hd (add_term_vars (concl_of nchotomy', []))), cert Ma)] nchotomy'
       
   540       in
       
   541         prove_goalw_cterm [] (cert t) (fn prems => 
       
   542           let val simplify = asm_simp_tac (HOL_ss addsimps (prems @ case_rewrites))
       
   543           in [simp_tac (HOL_ss addsimps [hd prems]) 1,
       
   544               cut_facts_tac [nchotomy''] 1,
       
   545               REPEAT (etac disjE 1 THEN REPEAT (etac exE 1) THEN simplify 1),
       
   546               REPEAT (etac exE 1) THEN simplify 1 (* Get last disjunct *)]
       
   547           end)
       
   548       end;
       
   549 
       
   550     val case_congs = map prove_case_cong (DatatypeProp.make_case_congs
       
   551       new_type_names descr sorts thy ~~ nchotomys ~~ case_thms)
       
   552 
       
   553   in
       
   554     (store_thms "case_cong" new_type_names case_congs thy, case_congs)
       
   555   end;
       
   556 
       
   557 end;