src/HOL/datatype.ML
author nipkow
Tue Apr 08 10:48:42 1997 +0200 (1997-04-08)
changeset 2919 953a47dc0519
parent 2880 a0fde30aa126
child 3040 7d48671753da
permissions -rw-r--r--
Dep. on Provers/nat_transitive
     1 (* Title:       HOL/datatype.ML
     2    ID:          $Id$
     3    Author:      Max Breitling, Carsten Clasohm, Tobias Nipkow, Norbert Voelker,
     4                 Konrad Slind
     5    Copyright 1995 TU Muenchen
     6 *)
     7 
     8 
     9 (*used for constructor parameters*)
(* Argument types of datatype constructors, as analysed by
   Datatype.add_datatype:
     dtVar s       - a type variable; s includes its leading quote (e.g. "'a")
     dtTyp (ts, n) - type constructor n applied to the argument types ts
     dtRek (ts, n) - a recursive occurrence of the datatype being defined;
                     add_datatype's analyse_types turns a dtTyp naming the new
                     type (applied to exactly its type variables) into this *)
    10 datatype dt_type = dtVar of string |
    11   dtTyp of dt_type list * string |
    12   dtRek of dt_type list * string;
    13 
    structure Datatype =
    struct
    local

    val mysort = sort;
    open ThyParse HOLogic;
    exception Impossible;
    exception RecError of string;

    val is_dtRek = (fn dtRek _ => true  |  _  => false);

    (* Wrap a non-empty string in parentheses; "" is returned unchanged. *)
    fun opt_parens s = if s = "" then "" else enclose "(" ")" s;

    (* ----------------------------------------------------------------------- *)
    (* Derivation of the primrec combinator application from the equations     *)

    (* Replace every application fname(ls,xk,rs) in t by yk(ls,rs), for each
       pair (xk,yk) in pairs; rpos = length ls is the position of the recursive
       argument.  Applications of fname whose argument xk has no entry in pairs
       are kept as they are.  Raises RecError when an application of fname has
       fewer than rpos+1 arguments. *)
    fun subst_apps (_,_) [] t = t
      | subst_apps (fname,rpos) pairs t =
        let
        fun subst (Abs(a,T,t)) = Abs(a,T,subst t)
          | subst (funct $ body) =
            let val (f,b) = strip_comb (funct$body)
            in
              if is_Const f andalso fst(dest_Const f) = fname
                then
                  let val (ls,rest) = (take(rpos,b), drop(rpos,b));
                    val (xk,rs) = (hd rest,tl rest)
                      handle LIST _ => raise RecError "not enough arguments \
                       \ in recursive application on rhs"
                  in
                    (case assoc (pairs,xk) of
                       None   => list_comb(f, map subst b)
                     | Some U => list_comb(U, map subst (ls @ rs)))
                  end
              (* The head may itself be an abstraction (a beta-redex on the
                 rhs): descend into it as well, otherwise recursive calls
                 inside it would escape the substitution. *)
              else list_comb(subst f, map subst b)
            end
          | subst(t) = t
        in subst t end;

    (* Turn the rhs of one primrec equation into the function argument passed
       to the rec combinator for the constructor whose argument types are tc:
       - rargs: the constructor arguments whose type in tc is recursive,
       - subs: variables with names derived from rargs, renamed with respect
         to rhs (rename_wrt_term), standing for the recursive-call results,
       - recursive calls fname(ls,x,rs) with x in rargs are replaced by the
         corresponding variable of subs (subst_apps),
       - the result is abstracted over cargs @ subs @ ls @ rs. *)
    fun abst_rec (fname,rpos,tc,ls,cargs,rs,rhs) =
      let val rargs = (map #1 o
                       (filter (fn (_,T) => is_dtRek T))) (cargs ~~ tc);
          val subs = map (fn (s,_) => (s,dummyT))
                       (rev(rename_wrt_term rhs rargs));
          val subst_rhs = subst_apps (fname,rpos)
                            (map Free rargs ~~ map Free subs) rhs;
      in
          list_abs_free (cargs @ subs @ ls @ rs, subst_rhs)
      end;

    (* parsing the prim rec equations *)

    (* Split a Trueprop-wrapped HOL equation "lhs = rhs" into (lhs, rhs). *)
    fun dest_eq ( Const("Trueprop",_) $ (Const ("op =",_) $ lhs $ rhs))
                     = (lhs, rhs)
       | dest_eq _ = raise RecError "not a proper equation";

    (* Destruct a primrec equation
         name vl1 .. vlrpos (C vi1 .. vin) vr1 .. vrm = rhs
       into (name, rpos, ls, C, cargs, rs, rhs), where ls, cargs and rs are the
       (name, type) pairs of the variables vl*, vi* and vr*.  Every malformed
       pattern is reported by raising RecError. *)
    fun dest_rec eq =
      let val (lhs,rhs) = dest_eq eq;
        val (name,args) = strip_comb lhs;
        val (ls',rest)  = take_prefix is_Free args;
        val (middle,rs') = take_suffix is_Free rest;
        val rpos = length ls';
        val (c,cargs') = strip_comb (hd middle)
          handle LIST "hd" => raise RecError "constructor missing";
        val (ls,cargs,rs) = (map dest_Free ls', map dest_Free cargs'
                             , map dest_Free rs')
          handle TERM ("dest_Free",_) =>
              raise RecError "constructor has illegal argument in pattern";
      in
        if length middle > 1 then
          raise RecError "more than one non-variable in pattern"
        else if not(null(findrep (map fst (ls @ rs @ cargs)))) then
          raise RecError "repeated variable name in pattern"
        else (fst(dest_Const name) handle TERM _ =>
                raise RecError "function is not declared as constant in theory"
              ,rpos,ls
              (* The non-variable argument must be headed by a constant (the
                 constructor).  Report anything else as RecError, so that it
                 is caught by trans_recs instead of escaping as TERM. *)
              ,fst(dest_Const c) handle TERM _ =>
                raise RecError "constructor expected in pattern"
              ,cargs,rs,rhs)
      end;

    (* Check that there is exactly one function term per constructor (n is the
       number of constructors).  Return the terms sorted by constructor index,
       dropping the indices.  Duplicate constructors are rejected earlier, by
       trans_recs. *)
    fun check_and_sort (n,its) =
      if length its = n
        then map snd (mysort (fn ((i : int,_),(j,_)) => i<j) its)
      else error "Primrec definition error:\n\
       \Please give an equation for every constructor";

    (* Translate the primrec equations into the function arguments of the rec
       combinator.  Result: (function name, ls of the first equation, function
       terms sorted by constructor).  cs' is the analysed constructor list of
       the datatype.  thy is only used to print terms in error messages. *)
    fun trans_recs _ _ [] = error("No primrec equations.")
      | trans_recs thy cs' (eq1::eqs) =
        let val (name1,rpos1,ls1,_,_,_,_) = dest_rec eq1
          handle RecError s =>
            error("Primrec definition error: " ^ s ^ ":\n"
                  ^ "   " ^ Sign.string_of_term (sign_of thy) eq1);
          val tcs = map (fn (_,c,T,_,_) => (c,T)) cs';
          val cs = map fst tcs;
          (* cis: indices of the constructors already covered by an equation *)
          fun trans_recs' _ [] = []
            | trans_recs' cis (eq::eqs) =
              let val (name,rpos,ls,c,cargs,rs,rhs) = dest_rec eq;
                val tc = assoc(tcs,c);
                (* 1 + position of c in cs; 0 when c is not a constructor *)
                val i = (1 + find (c,cs))  handle LIST "find" => 0;
              in
              if name <> name1 then
                raise RecError "function names inconsistent"
              else if rpos <> rpos1 then
                raise RecError "position of rec. argument inconsistent"
              else if i = 0 then
                raise RecError "illegal argument in pattern"
              else if i mem cis then
                raise RecError "constructor already occurred as pattern"
              else (i,abst_rec (name,rpos,the tc,ls,cargs,rs,rhs))
                   :: trans_recs' (i::cis) eqs
              end
              handle RecError s =>
                    error("Primrec definition error\n" ^ s ^ "\n"
                          ^ "   " ^ Sign.string_of_term (sign_of thy) eq);
        in (  name1, ls1
            , check_and_sort (length cs, trans_recs' [] (eq1::eqs)))
        end ;

    in
      (* Add the datatype (typevars) tname with constructors cons_list' to thy.
         This declares the type, its arity, the constructors, and the constants
         tname_case and tname_rec (plus tname_ord when there are at least dtK
         constructors).  It also adds the case-syntax translation and the
         axioms (induction, injectivity, distinctness, case and rec rules).
         Returns the extended theory together with add_primrec, which turns
         primrec equations over this datatype into a definition. *)
      fun add_datatype (typevars, tname, cons_list') thy =
        let
          (* with dtK or more constructors, distinctness is stated via the
             constant tname_ord (into nat), which needs theory Nat *)
          val _ = if length cons_list' < dtK then ()
                  else require_thy thy "Nat" "datatype";

          (* base name for variables of a given argument type *)
          fun typid(dtRek(_,id)) = id
            | typid(dtVar s) = implode (tl (explode s))   (* drop the quote *)
            | typid(dtTyp(_,id)) = id;

          (* make names unique: a name occurring more than once gets the
             suffixes 1, 2, ... in order of occurrence *)
          fun index_vnames(vn::vns,tab) =
                (case assoc(tab,vn) of
                   None => if vn mem vns
                           then (vn^"1") :: index_vnames(vns,(vn,2)::tab)
                           else vn :: index_vnames(vns,tab)
                 | Some(i) => (vn^(string_of_int i)) ::
                              index_vnames(vns,(vn,i+1)::tab))
            | index_vnames([],tab) = [];

          fun mk_var_names types = index_vnames(map typid types,[]);

          (* Reject free type variables on the rhs.  Turn occurrences of tname
             (which must be applied to exactly typevars) into dtRek.  Result:
             (cons, constant name, analysed types, variable names, syn). *)
          fun analyse_types (cons, types, syn) =
            let fun analyse(t as dtVar v) =
                      if t mem typevars then t
                      else error ("Free type variable " ^ v ^ " on rhs.")
                  | analyse(dtTyp(typl,s)) =
                      if tname <> s then dtTyp(analyses typl, s)
                      else if typevars = typl then dtRek(typl, s)
                           else error (s ^ " used in different ways")
                  | analyse(dtRek _) = raise Impossible
                and analyses ts = map analyse ts;
            in (cons, Syntax.const_name cons syn, analyses types,
                mk_var_names types, syn)
            end;

          (* Reject the datatype if every constructor has at least one
             recursive argument: no value could then be built, i.e. the type
             would be empty. *)
          fun non_empty (cs : ('a * 'b * dt_type list * 'c *'d) list) =
            not(forall (exists is_dtRek o #3) cs) orelse
            error("Empty datatype not allowed!");

          val cons_list = map analyse_types cons_list';
          val _ = non_empty cons_list;
          val num_of_cons = length cons_list;

          (* Auxiliary functions to construct argument and equation lists *)

          (* generate "var_n<delim>...<delim>var_m" *)
          fun Args(var, delim, n, m) =
            space_implode delim (map (fn n => var^string_of_int(n)) (n upto m));

          (* "name (v1) (v2) ..." for the argument names vns *)
          fun C_exp name vns = name ^ opt_parens(space_implode ") (" vns);

          (* arg_eqs [x1,...,xn] [y1,...,yn] = "x1=y1 & ... & xn=yn" *)
          fun arg_eqs vns vns' =
            let fun mkeq(x,x') = x ^ "=" ^ x'
            in space_implode " & " (ListPair.map mkeq (vns,vns')) end;

          (* Pretty printers for type lists;
             pp_typlist1: parentheses, pp_typlist2: brackets *)
          fun pp_typ (dtVar s) = "(" ^ s ^ "::term)"
            | pp_typ (dtTyp (typvars, id)) =
              if null typvars then id else (pp_typlist1 typvars) ^ id
            | pp_typ (dtRek (typvars, id)) = (pp_typlist1 typvars) ^ id
          and
            pp_typlist' ts = commas (map pp_typ ts)
          and
            pp_typlist1 ts = if null ts then "" else parens (pp_typlist' ts);

          fun pp_typlist2 ts = if null ts then "" else brackets (pp_typlist' ts);

          (* Generate both sides of the case-syntax translation:
             "C1 y1 .. yk => z1 | ..." and "(% y1 .. yk. z1) ...".
             c_nr numbers the constructor; y_nr is the next unused y index. *)
          fun calc_xrules c_nr y_nr ((_, name, _, vns, _) :: cs) =
            let val arity = length vns;
              val body  = "z" ^ string_of_int(c_nr);
              val args1 = if arity=0 then ""
                          else " " ^ Args ("y", " ", y_nr, y_nr+arity-1);
              val args2 = if arity=0 then ""
                          else "(% " ^ Args ("y", " ", y_nr, y_nr+arity-1)
                            ^ ". ";
              val (rest1,rest2) =
                if null cs then ("","")
                else let val (h1, h2) = calc_xrules (c_nr+1) (y_nr+arity) cs
                in (" | " ^ h1, " " ^ h2) end;
            in (name ^ args1 ^ " => " ^ body ^ rest1,
                args2 ^ body ^ (if args2 = "" then "" else ")") ^ rest2)
            end
            | calc_xrules _ _ [] = raise Impossible;

          val xrules =
            let val (first_part, scnd_part) = calc_xrules 1 1 cons_list
            in [Syntax.<-> (("logic", "case x of " ^ first_part),
                            ("logic", tname ^ "_case " ^ scnd_part ^ " x"))]
            end;

          (* type declarations for constructors *)
          fun const_type (id, _, typlist, _, syn) =
            (id,
             (if null typlist then "" else pp_typlist2 typlist ^ " => ") ^
                pp_typlist1 typevars ^ tname, syn);

          (* Induction hypotheses "[| P(v); ...; P(w) |] ==>" for the arguments
             of recursive type; "" when there are none.  found records whether
             a hypothesis has been emitted yet. *)
          fun assumpt (dtRek _ :: ts, v :: vs ,found) =
            let val h = if found then ";P(" ^ v ^ ")" else "[| P(" ^ v ^ ")"
            in h ^ (assumpt (ts, vs, true)) end
            | assumpt (t :: ts, v :: vs, found) = assumpt (ts, vs, found)
            | assumpt ([], [], found) = if found then "|] ==>" else ""
            | assumpt _ = raise Impossible;

          (* one induction premise per constructor, separated by "; " *)
          fun t_inducting ((_, name, types, vns, _) :: cs) =
            let
              val h = if null types then " P(" ^ name ^ ")"
                      else " !!" ^ (space_implode " " vns) ^ "." ^
                        (assumpt (types, vns, false)) ^
                        "P(" ^ C_exp name vns ^ ")";
              val rest = t_inducting cs;
            in if rest = "" then h else h ^ "; " ^ rest end
            | t_inducting [] = "";

          (* the structural induction rule as a string *)
          fun t_induct cl typ_name =
            "[|" ^ t_inducting cl ^ "|] ==> P(" ^ typ_name ^ ")";

          (* Argument types of the case/rec constants: for each constructor
             "[f ts]=>typevar," (or just "typevar," without arguments), where f
             transforms the constructor's argument types. *)
          fun gen_typlist typevar f ((_, _, ts, _, _) :: cs) =
            let val h = if (length ts) > 0
                          then pp_typlist2(f ts) ^ "=>"
                        else ""
            in h ^ typevar ^  "," ^ (gen_typlist typevar f cs) end
            | gen_typlist _ _ [] = "";

          (* ---------------------------------------------------------------- *)
          (* The case constant and rules                                      *)

          val t_case = tname ^ "_case";

          (* "tname_case f1 .. fn (C_n args) = f_n args" *)
          fun case_rule n (id, name, _, vns, _) =
            let val args = if vns = [] then "" else " " ^ space_implode " " vns
            in (t_case ^ "_" ^ id,
                t_case ^ " " ^ Args("f", " ", 1, num_of_cons)
                ^ " (" ^ name ^ args ^ ") = f"^string_of_int(n) ^ args)
            end

          fun case_rules n (c :: cs) = case_rule n c :: case_rules(n+1) cs
            | case_rules _ [] = [];

          val datatype_arity = length typevars;

          val types = [(tname, datatype_arity, NoSyn)];

          val arities =
            let val term_list = replicate datatype_arity termS;
            in [(tname, term_list, termS)]
            end;

          val datatype_name = pp_typlist1 typevars ^ tname;

          (* result type variable of the case/rec constants, distinct from
             the datatype's own type variables *)
          val new_tvar_name = variant (map (fn dtVar s => s) typevars) "'z";

          val case_const =
            (t_case,
             "[" ^ gen_typlist new_tvar_name I cons_list
             ^  pp_typlist1 typevars ^ tname ^ "] =>" ^ new_tvar_name^"::term",
             NoSyn);

          val rules_case = case_rules 1 cons_list;

          (* ---------------------------------------------------------------- *)
          (* The prim-rec combinator                                          *)

          val t_rec = tname ^ "_rec"

          (* append one result-type variable per dtRek type of ts *)
          fun add_reks ts =
            ts @ map (fn _ => dtVar new_tvar_name) (filter is_dtRek ts);

          (* the variable names in vns whose type (at the same position in
             ts) is a recursive occurrence (dtRek), in order *)
          fun rek_vars ts vns = map #2 (filter (is_dtRek o fst) (ts ~~ vns))

          (* "tname_rec f1..fn (C_n args) = f_n args (tname_rec f1..fn r)..."
             with one recursive call per recursive argument r *)
          fun rec_rule n (id,name,ts,vns,_) =
            let val args = opt_parens(space_implode ") (" vns)
              val fargs = opt_parens(Args("f", ") (", 1, num_of_cons))
              fun rarg vn = t_rec ^ fargs ^ " (" ^ vn ^ ")"
              val rargs = opt_parens(space_implode ") ("
                                     (map rarg (rek_vars ts vns)))
            in
              (t_rec ^ "_" ^ id,
               t_rec ^ fargs ^ " (" ^ name ^ args ^ ") = f"
               ^ string_of_int(n) ^ args ^ rargs)
            end

          fun rec_rules n (c::cs) = rec_rule n c :: rec_rules (n+1) cs
            | rec_rules _ [] = [];

          val rec_const =
            (t_rec,
             "[" ^ (gen_typlist new_tvar_name add_reks cons_list)
             ^ (pp_typlist1 typevars) ^ tname ^ "] =>" ^ new_tvar_name^"::term",
             NoSyn);

          val rules_rec = rec_rules 1 cons_list

          (* ---------------------------------------------------------------- *)
          val consts =
            map const_type cons_list
            @ (if num_of_cons < dtK then []
               else [(tname ^ "_ord", datatype_name ^ "=>nat", NoSyn)])
            @ [case_const,rec_const];

          (* injectivity: "(C xs = C xs') = (x1=x1' & ...)" for every
             constructor with arguments *)
          fun Ci_ing ((id, name, _, vns, _) :: cs) =
               if null vns then Ci_ing cs
               else let val vns' = variantlist(vns,vns)
                    in ("inject_" ^ id,
                        "(" ^ (C_exp name vns) ^ "=" ^ (C_exp name vns')
                        ^ ") = (" ^ (arg_eqs vns vns') ^ ")") :: (Ci_ing cs)
                    end
            | Ci_ing [] = [];

          (* distinctness of two constructors: "C1 xs ~= C2 ys" *)
          fun Ci_negOne (id1,name1,_,vns1,_) (id2,name2,_,vns2,_) =
                let val vns2' = variantlist(vns2,vns1)
                    val ax = C_exp name1 vns1 ^ "~=" ^ C_exp name2 vns2'
            in (id1 ^ "_not_" ^ id2, ax) end;

          (* pairwise distinctness axioms (quadratic in the number of
             constructors; used below dtK constructors) *)
          fun Ci_neg1 [] = []
            | Ci_neg1 (c1::cs) = (map (Ci_negOne c1) cs) @ Ci_neg1 cs;

          (* the numeral n as "Suc(...Suc(0)...)" *)
          fun suc_expr n =
            if n=0 then "0" else "Suc(" ^ suc_expr(n-1) ^ ")";

          (* distinctness via tname_ord: the n-th constructor (from 0) is
             mapped to n, and different ord values imply different terms *)
          fun Ci_neg2() =
            let val ord_t = tname ^ "_ord";
              val cis = ListPair.zip (cons_list, 0 upto (num_of_cons - 1))
              fun Ci_neg2equals ((id, name, _, vns, _), n) =
                let val ax = ord_t ^ "(" ^ (C_exp name vns) ^ ") = " ^ (suc_expr n)
                in (ord_t ^ "_" ^ id, ax) end
            in (ord_t ^ "_distinct", ord_t^"(x) ~= "^ord_t^"(y) ==> x ~= y") ::
              (map Ci_neg2equals cis)
            end;

          val rules_distinct = if num_of_cons < dtK then Ci_neg1 cons_list
                               else Ci_neg2();

          val rules_inject = Ci_ing cons_list;

          val rule_induct = (tname ^ "_induct", t_induct cons_list tname);

          val rules = rule_induct ::
            (rules_inject @ rules_distinct @ rules_case @ rules_rec);

          (* Turn primrec equations (strings) over this datatype into a
             definition "fname == %ls x. tname_rec fns x ls" and add it to
             thy. *)
          fun add_primrec eqns thy =
            let val sg = sign_of thy;
              val rec_comb = Const(t_rec,dummyT)
              val teqns = map (fn neq => snd(read_axm sg neq)) eqns
              val (fname,ls,fns) = trans_recs thy cons_list teqns
              (* The rec combinator takes the recursive argument first and
                 then the arguments ls.  Bound 0 is the innermost binder (the
                 recursive argument); Bound (length ls) .. Bound 1 are ls. *)
              val rhs =
                list_abs_free
                (ls @ [(tname,dummyT)]
                 ,list_comb(rec_comb
                            , fns @ map Bound (0 ::(length ls downto 1))));
              val defpair = (fname ^ "_" ^ tname ^ "_def",
                             Logic.mk_equals (Const(fname,dummyT), rhs))
              val defpairT = inferT_axm sg defpair;
              (* sanity check: fails (via the) unless fname is a declared
                 constant of sg *)
              val _ = the (Sign.const_type sg fname);
            in add_defs_i [defpairT] thy end;

        in
          (thy |> add_types types
               |> add_arities arities
               |> add_consts consts
               |> add_trrules xrules
               |> add_axioms rules, add_primrec)
        end
    end
    end
   415 
   416 (*
   417 Informal description of functions used in datatype.ML for the Isabelle/HOL
   418 implementation of prim. rec. function definitions. (N. Voelker, Feb. 1995) 
   419 
   420 * subst_apps (fname,rpos) pairs t:
   421    substitute the term 
   422        fname(ls,xk,rs) 
   423    by 
   424       yk(ls,rs) 
   425    in t for (xk,yk) in pairs, where rpos = length ls. 
   426    Applied with : 
   427      fname = function name 
   428      rpos = position of recursive argument 
   429      pairs = list of pairs (xk,yk), where 
   430           xk are the rec. arguments of the constructor in the pattern,
   431           yk is a variable with name derived from xk 
   432      t = rhs of equation 
   433 
   434 * abst_rec (fname,rpos,tc,ls,cargs,rs,rhs)
   435   - filter recursive arguments from constructor arguments cargs,
   436   - perform substitutions on rhs, 
   437   - derive list subs of new variable names yk for use in subst_apps, 
   438   - abstract rhs with respect to cargs, subs, ls and rs. 
   439 
   440 * dest_eq t 
   441   destruct a term denoting an equation into lhs and rhs. 
   442 
   443 * dest_rec eq
   444   destruct an equation of the form 
   445       name (vl1..vlrpos, Ci(vi1..vin), vr1..vrn) = rhs
   446   into 
   447   - function name  (name) 
   448   - position of the first non-variable parameter  (rpos)
   449   - the list of first rpos parameters (ls = [vl1..vlrpos]) 
   450   - the constructor (fst( dest_Const c) = Ci)
   451   - the arguments of the constructor (cargs = [vi1..vin])
   452   - the rest of the variables in the pattern (rs = [vr1..vrn])
   453   - the right hand side of the equation (rhs).  
   454  
   455 * check_and_sort (n,its)
   456   check that  n = length its holds, and sort elements of its by 
   457   first component. 
   458 
   459 * trans_recs thy cs' (eq1::eqs)
   460   destruct eq1 into name1, rpos1, ls1, etc.. 
   461   get constructor list with and without type (tcs resp. cs) from cs',  
   462   for every equation:  
   463     destruct it into (name,rpos,ls,c,cargs,rs,rhs)
   464     get typed constructor tc from c and tcs 
   465     determine the index i of the constructor 
   466     check function name and position of rec. argument by comparison
   467     with first equation 
   468     check for repeated variable names in pattern
   469     derive function term f_i which is used as argument of the rec. combinator
   470     sort the terms f_i according to i and return them together
   471       with the function name and the parameters of the definition (ls).
   472 
   473 * Application:
   474 
   475   The rec. combinator is applied to the function terms resulting from
   476   trans_recs. This results in a function which takes the recursive arg.
   477   as first parameter and then the arguments corresponding to ls. The
   478   order of parameters is corrected by setting the rhs equal to 
   479 
   480   list_abs_free
   481             (ls @ [(tname,dummyT)]
   482              ,list_comb(rec_comb
   483                         , fns @ map Bound (0 ::(length ls downto 1))));
   484 
   485   Note the de-Bruijn indices counting the number of lambdas between the
   486   variable and its binding. 
   487 *)
   488 
   489 
   490 
   491 (* ----------------------------------------------- *)
   492 (* The following has been written by Konrad Slind. *)
   493 
   494 
   (* Facts packaged for a datatype; Dtype_sig.build_record returns one of
      these, paired with a string. *)
   495 type dtype_info = {case_const:term, case_rewrites:thm list,
   496                    constructors:term list, nchotomy:thm, case_cong:thm};
   497 
   (* Interface of the Dtype structure: stating and proving the case
      congruence and exhaustion ("nchotomy") theorems of a datatype. *)
   498 signature Dtype_sig =
   499 sig
   (* State the case congruence / exhaustion theorem from the case rewrite
      rules of a datatype, certified in the given signature. *)
   500   val build_case_cong: Sign.sg -> thm list -> cterm
   501   val build_nchotomy: Sign.sg -> thm list -> cterm
   502 
   (* prove_case_cong nchotomy case_rewrites goal: prove a goal produced by
      build_case_cong.  prove_nchotomy takes a tactic-producing function;
      NOTE(review): its implementation is not in this excerpt, so the meaning
      of the string/int arguments is unconfirmed. *)
   503   val prove_case_cong: thm -> thm list -> cterm -> thm
   504   val prove_nchotomy: (string -> int -> tactic) -> cterm -> thm
   505 
   (* Build and prove both theorems (implementation not in this excerpt). *)
   506   val case_thms : Sign.sg -> thm list -> (string -> int -> tactic)
   507                    -> {nchotomy:thm, case_cong:thm}
   508 
   (* Collect everything into a dtype_info record (implementation not in this
      excerpt). *)
   509   val build_record : (theory * (string * string list)
   510                       * (string -> int -> tactic))
   511                      -> (string * dtype_info) 
   512 
   513 end;
   514 
   515 
   516 (*---------------------------------------------------------------------------
   517  * This structure is support for the Isabelle datatype package. It provides
   518  * entrypoints for 1) building and proving the case congruence theorem for
   519  * a datatype and 2) building and proving the "exhaustion" theorem for
   520  * a datatype (I have called this theorem "nchotomy" for no good reason).
   521  *
   522  * It also brings all these together in the function "build_record", which
   523  * is probably what will be used.
   524  *
   525  * Since these routines are required in order to support TFL, they have
   526  * been written so they will compile "stand-alone", i.e., in Isabelle-HOL
   527  * without any TFL code around.
   528  *---------------------------------------------------------------------------*)
   529 structure Dtype : Dtype_sig =
   530 struct
   531 
   (* Error raised by the routines of this structure: func names the routine
      that failed and mesg gives a short description. *)
   532 exception DTYPE_ERR of {func:string, mesg:string};
   533 
   534 (*---------------------------------------------------------------------------
   535  * General support routines
   536  *---------------------------------------------------------------------------*)
   (* Curried right fold:
        itlist f [x1, ..., xn] b = f x1 (f x2 (... (f xn b)))  *)
   fun itlist f L base_value =
     let fun fold_from (x :: xs) = f x (fold_from xs)
           | fold_from [] = base_value
     in fold_from L
     end;
   542 
   (* Right fold that uses the last element as the base value:
        end_itlist f [x1, ..., xn] = f x1 (f x2 (... (f x(n-1) xn)))
      Raises DTYPE_ERR on the empty list. *)
   fun end_itlist f =
     let fun endit [] = raise DTYPE_ERR{func="end_itlist", mesg="list too short"}
           | endit [last] = last
           | endit (x :: rest) = f x (endit rest)
     in endit
     end;

   (* unzip [(x1,y1), ..., (xn,yn)] = ([x1, ..., xn], [y1, ..., yn]) *)
   fun unzip L = (map (fn (x, _) => x) L, map (fn (_, y) => y) L);
   552 
   553 
   554 (*---------------------------------------------------------------------------
   555  * Miscellaneous Syntax manipulation
   556  *---------------------------------------------------------------------------*)
   (* Thin constructor/destructor wrappers over Isabelle terms.  The dest_*
      functions raise DTYPE_ERR when the term has a different shape. *)
   557 val mk_var = Free;
   558 val mk_const = Const
   559 fun mk_comb(Rator,Rand) = Rator $ Rand;
   (* mk_abs (v, body): abstract the variable v (a Var or a Free) out of body,
      naming the new binder after v. *)
   560 fun mk_abs(r as (Var((s,_),ty),_))  = Abs(s,ty,abstract_over r)
   561   | mk_abs(r as (Free(s,ty),_))     = Abs(s,ty,abstract_over r)
   562   | mk_abs _ = raise DTYPE_ERR{func="mk_abs", mesg="1st not a variable"};
   563 
   (* Name and type of a Var or Free; the index of a Var is dropped. *)
   564 fun dest_var(Var((s,i),ty)) = (s,ty)
   565   | dest_var(Free(s,ty))    = (s,ty)
   566   | dest_var _ = raise DTYPE_ERR{func="dest_var", mesg="not a variable"};
   567 
   568 fun dest_const(Const p) = p
   569   | dest_const _ = raise DTYPE_ERR{func="dest_const", mesg="not a constant"};
   570 
   571 fun dest_comb(t1 $ t2) = (t1,t2)
   572   | dest_comb _ =  raise DTYPE_ERR{func = "dest_comb", mesg = "not a comb"};
   (* argument / function part of an application t1 $ t2 *)
   573 val rand = #2 o dest_comb;
   574 val rator = #1 o dest_comb;
   575 
   (* Open an abstraction: return a Free named after the bound variable,
      together with the body instantiated with that Free (via betapply).
      NOTE(review): the name is not renamed away from free variables already
      occurring in the body, so a clash would conflate them -- confirm that
      callers never rely on freshness. *)
   576 fun dest_abs(a as Abs(s,ty,M)) = 
   577      let val v = Free(s, ty)
   578       in (v, betapply (a,v)) end
   579   | dest_abs _ =  raise DTYPE_ERR{func="dest_abs", mesg="not an abstraction"};
   580 
   581 
   (* HOL object-level types used below. *)
   582 val bool = Type("bool",[])
   583 and prop = Type("prop",[]);
   584 
   (* HOL equality lhs = rhs, typed after lhs. *)
   585 fun mk_eq(lhs,rhs) = 
   586    let val ty = type_of lhs
   587        val c = mk_const("op =", ty --> ty --> bool)
   588    in list_comb(c,[lhs,rhs])
   589    end
   590 
   591 fun dest_eq(Const("op =",_) $ M $ N) = (M, N)
   592   | dest_eq _ = raise DTYPE_ERR{func="dest_eq", mesg="not an equality"};
   593 
   (* HOL disjunction disj1 | disj2. *)
   594 fun mk_disj(disj1,disj2) =
   595    let val c = Const("op |", bool --> bool --> bool)
   596    in list_comb(c,[disj1,disj2])
   597    end;
   598 
   (* HOL quantifiers !Bvar. body and ? Bvar. body, where r = (Bvar, body). *)
   599 fun mk_forall (r as (Bvar,_)) = 
   600   let val ty = type_of Bvar
   601       val c = Const("All", (ty --> bool) --> bool)
   602   in mk_comb(c, mk_abs r)
   603   end;
   604 
   605 fun mk_exists (r as (Bvar,_)) = 
   606   let val ty = type_of Bvar 
   607       val c = Const("Ex", (ty --> bool) --> bool)
   608   in mk_comb(c, mk_abs r)
   609   end;
   610 
   (* Add a Trueprop coercion unless one is already present. *)
   611 fun mk_prop (tm as Const("Trueprop",_) $ _) = tm
   612   | mk_prop tm = mk_comb(Const("Trueprop", bool --> prop),tm);
   613 
   (* Strip a Trueprop coercion if present. *)
   614 fun drop_prop (Const("Trueprop",_) $ X) = X
   615   | drop_prop X = X;
   616 
   (* mk_all quantifies with Isabelle's `all` (presumably the meta-level !!).
      The list_mk_* variants quantify/disjoin a whole list, innermost last;
      list_mk_disj raises DTYPE_ERR (from end_itlist) on an empty list. *)
   617 fun mk_all (r as (Bvar,_)) = mk_comb(all (type_of Bvar), mk_abs r);
   618 fun list_mk_all(V,t) = itlist(fn v => fn b => mk_all(v,b)) V t;
   619 fun list_mk_exists(V,t) = itlist(fn v => fn b => mk_exists(v,b)) V t;
   620 val list_mk_disj = end_itlist(fn d1 => fn tm => mk_disj(d1,tm))
   621 
   622 
   (* Hypotheses and conclusion of a theorem, with Trueprop stripped. *)
   623 fun dest_thm thm = 
   624    let val {prop,hyps,...} = rep_thm thm
   625    in (map drop_prop hyps, drop_prop prop)
   626    end;
   627 
   628 val concl = #2 o dest_thm;
   629 
   630 
   631 (*---------------------------------------------------------------------------
   632  * Names of all variables occurring in a term, including bound ones. These
   633  * are added into the second argument.
   634  *---------------------------------------------------------------------------*)
   (* add_term_names tm names: insert the names of all Free, Var and
      Abs-bound variables of tm into names.  names is kept in ascending order
      without duplicates, provided it was so on entry. *)
   fun add_term_names tm =
     let
       (* insert x into the ascending list l unless it is already present *)
       fun insert_sorted (x : string) [] = [x]
         | insert_sorted x (l as (y :: ys)) =
             if x < y then x :: l
             else if x = y then l
             else y :: insert_sorted x ys
       fun collect (Free (s, _)) acc = insert_sorted s acc
         | collect (Var ((s, _), _)) acc = insert_sorted s acc
         | collect (Abs (s, _, body)) acc = collect body (insert_sorted s acc)
         | collect (f $ t) acc = collect t (collect f acc)
         | collect _ acc = acc
     in collect tm
     end;
   650 
   651 
   652 (*---------------------------------------------------------------------------
   653  * We need to make everything free, so that we can put the term into a
   654  * goalstack, or submit it as an argument to prove_goalw_cterm.
   655  *---------------------------------------------------------------------------*)
   (* Replace schematic type variables by free ones of the same name and sort.
      NOTE(review): the index is dropped, so ?'a0 and ?'a1 would collide. *)
   fun make_free_ty (Type (name, args)) = Type (name, map make_free_ty args)
     | make_free_ty (TVar ((name, _), srt)) = TFree (name, srt)
     | make_free_ty other = other;

   (* Replace schematic term variables by Frees (same caveat about indices)
      and make all types in the term free via make_free_ty. *)
   fun make_free tm =
     case tm of
       Var ((name, _), ty) => Free (name, make_free_ty ty)
     | Abs (name, ty, body) => Abs (name, make_free_ty ty, make_free body)
     | f $ t => make_free f $ make_free t
     | Const (name, ty) => Const (name, make_free_ty ty)
     | Free (name, ty) => Free (name, make_free_ty ty)
     | other => other;
   666 
   667 
   668 (*---------------------------------------------------------------------------
   669  * Structure of case congruence theorem looks like this:
   670  *
   671  *    (M = M') 
   672  *    ==> (!!x1,...,xk. (M' = C1 x1..xk) ==> (f1 x1..xk = f1' x1..xk)) 
   673  *    ==> ... 
   674  *    ==> (!!x1,...,xj. (M' = Cn x1..xj) ==> (fn x1..xj = fn' x1..xj)) 
   675  *    ==>
   676  *      (ty_case f1..fn M = ty_case f1'..fn' M')
   677  *
   678  * The input is the list of rules for the case construct for the type, i.e.,
   679  * that found in the "ty.cases" field of a theory where datatype "ty" is
   680  * defined.
   681  *---------------------------------------------------------------------------*)
   682 
   (* State the case congruence theorem of a datatype from its case rewrite
      rules "ty_case f1..fn (Ci xs) = fi xs", certified in sign.  Any failure
      is reported as DTYPE_ERR "build_case_cong". *)
   683 fun build_case_cong sign case_rewrites =
   684  let val clauses = map concl case_rewrites
   685      val clause1 = hd clauses
   (* left = "ty_case f1..fn (C1 xs)"; ty is the type of its last argument *)
   686      val left = (#1 o dest_eq) clause1
   687      val ty = type_of ((#2 o dest_comb) left)
   (* fresh names M and M' (both of type ty) avoiding all names in clauses *)
   688      val varnames = itlist add_term_names clauses []
   689      val M = variant varnames "M"
   690      val Mvar = Free(M, ty)
   691      val M' = variant (M::varnames) M
   692      val M'var = Free(M', ty)
   (* For "ty_case .. (Ci xs) = fi xs" produce a fresh fi' (named after fi
      with suffix "a") and the premise
        !!xs. M' = Ci xs ==> fi xs = fi' xs   *)
   693      fun mk_clause clause =
   694        let val (lhs,rhs) = dest_eq clause
   695            val func = (#1 o strip_comb) rhs
   696            val (constr,xbar) = strip_comb(rand lhs)
   697            val (Name,Ty) = dest_var func
   698            val func'name = variant (M::M'::varnames) (Name^"a")
   699            val func' = mk_var(func'name,Ty)
   700        in (func', list_mk_all
   701                   (xbar, Logic.mk_implies
   702                          (mk_prop(mk_eq(M'var, list_comb(constr,xbar))),
   703                           mk_prop(mk_eq(list_comb(func, xbar),
   704                                         list_comb(func',xbar))))))   end
   705      val (funcs',clauses') = unzip (map mk_clause clauses)
   (* lhsM = "ty_case f1..fn M"; c is the ty_case constant itself *)
   706      val lhsM = mk_comb(rator left, Mvar)
   707      val c = #1(strip_comb left)
   708  in
   (* M = M' ==> premises ==> ty_case f1..fn M = ty_case f1'..fn' M',
      with schematic variables made free *)
   709  cterm_of sign
   710   (make_free
   711    (Logic.list_implies(mk_prop(mk_eq(Mvar, M'var))::clauses',
   712                        mk_prop(mk_eq(lhsM, list_comb(c,(funcs'@[M'var])))))))
   713  end
   714  handle _ => raise DTYPE_ERR{func="build_case_cong",mesg="failed"};
   715 
   716   
   717 (*---------------------------------------------------------------------------
   718  * Proves the result of "build_case_cong". 
   719  * This one solves it a disjunct at a time, and builds the ss only once.
   720  *---------------------------------------------------------------------------*)
   721 fun prove_case_cong nchotomy case_rewrites ctm =
   722  let val {sign,t,...} = rep_cterm ctm
   723      val (Const("==>",_) $ tm $ _) = t
   724      val (Const("Trueprop",_) $ (Const("op =",_) $ _ $ Ma)) = tm
   725      val (Free(str,_)) = Ma
   726      val thm = prove_goalw_cterm[] ctm
   727       (fn prems => 
   728         let val simplify = asm_simp_tac(HOL_ss addsimps (prems@case_rewrites))
   729         in [simp_tac (HOL_ss addsimps [hd prems]) 1,
   730             cut_inst_tac [("x",str)] (nchotomy RS spec) 1,
   731             REPEAT (etac disjE 1 THEN REPEAT (etac exE 1) THEN simplify 1),
   732             REPEAT (etac exE 1) THEN simplify 1 (* Get last disjunct *)]
   733         end) 
   734  in standard (thm RS eq_reflection)
   735  end
   736  handle _ => raise DTYPE_ERR{func="prove_case_cong",mesg="failed"};
   737 
   738 
   739 (*---------------------------------------------------------------------------
   740  * Structure of exhaustion theorem looks like this:
   741  *
   742  *    !v. (EX y1..yi. v = C1 y1..yi) | ... | (EX y1..yj. v = Cn y1..yj)
   743  *
   744  * As for "build_case_cong", the input is the list of rules for the case 
   745  * construct (the case "rewrites").
   746  *---------------------------------------------------------------------------*)
   747 fun build_nchotomy sign case_rewrites =
   748  let val clauses = map concl case_rewrites
   749      val C_ybars = map (rand o #1 o dest_eq) clauses
   750      val varnames = itlist add_term_names C_ybars []
   751      val vname = variant varnames "v"
   752      val ty = type_of (hd C_ybars)
   753      val v = mk_var(vname,ty)
   754      fun mk_disj C_ybar =
   755        let val ybar = #2(strip_comb C_ybar)
   756        in list_mk_exists(ybar, mk_eq(v,C_ybar))
   757        end
   758  in
   759  cterm_of sign
   760    (make_free(mk_prop (mk_forall(v, list_mk_disj (map mk_disj C_ybars)))))
   761  end
   762  handle _ => raise DTYPE_ERR{func="build_nchotomy",mesg="failed"};
   763 
   764 
   765 (*---------------------------------------------------------------------------
   766  * Takes the induction tactic for the datatype, and the result from 
   767  * "build_nchotomy" 
   768  *
   769  *    !v. (EX y1..yi. v = C1 y1..yi) | ... | (EX y1..yj. v = Cn y1..yj)
   770  *
   771  * and proves the theorem. The proof works along a diagonal: the nth 
   772  * disjunct in the nth subgoal is easy to solve. Thus this routine depends 
   773  * on the order of goals arising out of the application of the induction 
   774  * tactic. A more general solution would have to use injectiveness and 
   775  * distinctness rewrite rules.
   776  *---------------------------------------------------------------------------*)
   777 fun prove_nchotomy induct_tac ctm =
   778  let val (Const ("Trueprop",_) $ g) = #t(rep_cterm ctm)
   779      val (Const ("All",_) $ Abs (v,_,_)) = g
   780      (* For goal i, select the correct disjunct to attack, then prove it *)
   781      fun tac i 0 = (rtac disjI1 i ORELSE all_tac) THEN
   782                    REPEAT (rtac exI i) THEN (rtac refl i)
   783        | tac i n = rtac disjI2 i THEN tac i (n-1)
   784  in 
   785  prove_goalw_cterm[] ctm
   786      (fn _ => [rtac allI 1,
   787                induct_tac v 1,
   788                ALLGOALS (fn i => tac i (i-1))])
   789  end
   790  handle _ => raise DTYPE_ERR {func="prove_nchotomy", mesg="failed"};
   791 
   792 
   793 (*---------------------------------------------------------------------------
 * Brings the preceding functions together.
   795  *---------------------------------------------------------------------------*)
   796 fun case_thms sign case_rewrites induct_tac =
   797   let val nchotomy = prove_nchotomy induct_tac
   798                                     (build_nchotomy sign case_rewrites)
   799       val cong = prove_case_cong nchotomy case_rewrites
   800                                  (build_case_cong sign case_rewrites)
   801   in {nchotomy=nchotomy, case_cong=cong}
   802   end;
   803 
   804 
   805 (*---------------------------------------------------------------------------
   806  * Tests
   807  *
   808  * 
   809      Dtype.case_thms (sign_of List.thy) List.list.cases List.list.induct_tac;
   810      Dtype.case_thms (sign_of Prod.thy) [split] 
   811                      (fn s => res_inst_tac [("p",s)] PairE_lemma);
   812      Dtype.case_thms (sign_of Nat.thy) [nat_case_0, nat_case_Suc] nat_ind_tac;
   813 
   814  *
   815  *---------------------------------------------------------------------------*)
   816 
   817 
   818 (*---------------------------------------------------------------------------
   819  * Given a theory and the name (and constructors) of a datatype declared in 
   820  * an ancestor of that theory and an induction tactic for that datatype, 
   821  * return the information that TFL needs. This should only be called once for
   822  * a datatype, because "build_record" proves various facts, and thus is slow. 
   823  * It fails on the datatype of pairs, which must be included for TFL to work. 
 * The test shows how to build the record for pairs.
   825  *---------------------------------------------------------------------------*)
   826 
   827 local fun mk_rw th = (th RS eq_reflection) handle _ => th
   828       fun get_fact thy s = (get_axiom thy s handle _ => get_thm thy s)
   829 in
   830 fun build_record (thy,(ty,cl),itac) =
   831  let val sign = sign_of thy
   832      fun const s = Const(s, the(Sign.const_type sign s))
   833      val case_rewrites = map (fn c => get_fact thy (ty^"_case_"^c)) cl
   834      val {nchotomy,case_cong} = case_thms sign case_rewrites itac
   835  in
   836   (ty, {constructors = map(fn s => const s handle _ => const("op "^s)) cl,
   837         case_const = const (ty^"_case"),
   838         case_rewrites = map mk_rw case_rewrites,
   839         nchotomy = nchotomy,
   840         case_cong = case_cong})
   841  end
   842 end;
   843 
   844 
   845 (*---------------------------------------------------------------------------
   846  * Test
   847  *
   848  * 
   849     map Dtype.build_record 
   850           [(Nat.thy, ("nat",["0", "Suc"]), nat_ind_tac),
   851            (List.thy,("list",["[]", "#"]), List.list.induct_tac)]
   852     @
   853     [let val prod_case_thms = Dtype.case_thms (sign_of Prod.thy) [split] 
   854                                  (fn s => res_inst_tac [("p",s)] PairE_lemma)
   855          fun const s = Const(s, the(Sign.const_type (sign_of Prod.thy) s))
   856      in ("*", 
   857          {constructors = [const "Pair"],
   858             case_const = const "split",
   859          case_rewrites = [split RS eq_reflection],
   860              case_cong = #case_cong prod_case_thms,
   861               nchotomy = #nchotomy prod_case_thms}) end];
   862 
   863  *
   864  *---------------------------------------------------------------------------*)
   865 
   866 end;