src/HOL/datatype.ML
author clasohm
Wed Mar 13 11:55:25 1996 +0100 (1996-03-13)
changeset 1574 5a63ab90ee8a
parent 1465 5d7a7e439cec
child 1668 8ead1fe65aad
permissions -rw-r--r--
modified primrec so it can be used in MiniML/Type.thy
     1 (* Title:       HOL/datatype.ML
     2    ID:          $Id$
     3    Author:      Max Breitling, Carsten Clasohm, Tobias Nipkow, Norbert Voelker
     4    Copyright 1995 TU Muenchen
     5 *)
     6 
     7 
     8 (*used for constructor parameters*)
     9 datatype dt_type = dtVar of string |
    10   dtTyp of dt_type list * string |
    11   dtRek of dt_type list * string;
    12 
    13 structure Datatype =
    14 struct
    15 local 
    16 
    17 val mysort = sort;
    18 open ThyParse HOLogic;
    19 exception Impossible;
    20 exception RecError of string;
    21 
    22 val is_dtRek = (fn dtRek _ => true  |  _  => false);
    23 fun opt_parens s = if s = "" then "" else enclose "(" ")" s; 
    24 
    25 (* ----------------------------------------------------------------------- *)
    26 (* Derivation of the primrec combinator application from the equations     *)
    27 
    28 (* substitute fname(ls,xk,rs) by yk(ls,rs) in t for (xk,yk) in pairs  *) 
    29 
    30 fun subst_apps (_,_) [] t = t
    31   | subst_apps (fname,rpos) pairs t =
    32     let 
    33     fun subst (Abs(a,T,t)) = Abs(a,T,subst t)
    34       | subst (funct $ body) = 
    35         let val (f,b) = strip_comb (funct$body)
    36         in 
    37           if is_Const f andalso fst(dest_Const f) = fname 
    38             then 
    39               let val (ls,rest) = (take(rpos,b), drop(rpos,b));
    40                 val (xk,rs) = (hd rest,tl rest)
    41                   handle LIST _ => raise RecError "not enough arguments \
    42                    \ in recursive application on rhs"
    43               in 
    44                 (case assoc (pairs,xk) of 
    45                    None   => list_comb(f, map subst b)
    46                  | Some U => list_comb(U, map subst (ls @ rs)))
    47               end
    48           else list_comb(f, map subst b)
    49         end
    50       | subst(t) = t
    51     in subst t end;
    52   
    53 (* abstract rhs *)
    54 
    55 fun abst_rec (fname,rpos,tc,ls,cargs,rs,rhs) =       
    56   let val rargs = (map fst o 
    57                    (filter (fn (a,T) => is_dtRek T))) (cargs ~~ tc);
    58       val subs = map (fn (s,T) => (s,dummyT))
    59                    (rev(rename_wrt_term rhs rargs));
    60       val subst_rhs = subst_apps (fname,rpos)
    61                         (map Free rargs ~~ map Free subs) rhs;
    62   in 
    63       list_abs_free (cargs @ subs @ ls @ rs, subst_rhs) 
    64   end;
    65 
    66 (* parsing the prim rec equations *)
    67 
    68 fun dest_eq ( Const("Trueprop",_) $ (Const ("op =",_) $ lhs $ rhs))
    69                  = (lhs, rhs)
    70    | dest_eq _ = raise RecError "not a proper equation"; 
    71 
    72 fun dest_rec eq = 
    73   let val (lhs,rhs) = dest_eq eq; 
    74     val (name,args) = strip_comb lhs; 
    75     val (ls',rest)  = take_prefix is_Free args; 
    76     val (middle,rs') = take_suffix is_Free rest;
    77     val rpos = length ls';
    78     val (c,cargs') = strip_comb (hd middle)
    79       handle LIST "hd" => raise RecError "constructor missing";
    80     val (ls,cargs,rs) = (map dest_Free ls', map dest_Free cargs'
    81                          , map dest_Free rs')
    82       handle TERM ("dest_Free",_) => 
    83           raise RecError "constructor has illegal argument in pattern";
    84   in 
    85     if length middle > 1 then 
    86       raise RecError "more than one non-variable in pattern"
    87     else if not(null(findrep (map fst (ls @ rs @ cargs)))) then 
    88       raise RecError "repeated variable name in pattern" 
    89          else (fst(dest_Const name) handle TERM _ => 
    90                raise RecError "function is not declared as constant in theory"
    91                  ,rpos,ls,fst( dest_Const c),cargs,rs,rhs)
    92   end; 
    93 
    94 (* check function specified for all constructors and sort function terms *)
    95 
    96 fun check_and_sort (n,its) = 
    97   if length its = n 
    98     then map snd (mysort (fn ((i : int,_),(j,_)) => i<j) its)
    99   else raise error "Primrec definition error:\n\
   100    \Please give an equation for every constructor";
   101 
   102 (* translate rec equations into function arguments suitable for rec comb *)
   103 (* theory parameter needed for printing error messages                   *) 
   104 
   105 fun trans_recs _ _ [] = error("No primrec equations.")
   106   | trans_recs thy cs' (eq1::eqs) = 
   107     let val (name1,rpos1,ls1,_,_,_,_) = dest_rec eq1
   108       handle RecError s =>
   109         error("Primrec definition error: " ^ s ^ ":\n" 
   110               ^ "   " ^ Sign.string_of_term (sign_of thy) eq1);
   111       val tcs = map (fn (_,c,T,_,_) => (c,T)) cs';  
   112       val cs = map fst tcs;
   113       fun trans_recs' _ [] = []
   114         | trans_recs' cis (eq::eqs) = 
   115           let val (name,rpos,ls,c,cargs,rs,rhs) = dest_rec eq; 
   116             val tc = assoc(tcs,c);
   117             val i = (1 + find (c,cs))  handle LIST "find" => 0; 
   118           in
   119           if name <> name1 then 
   120             raise RecError "function names inconsistent"
   121           else if rpos <> rpos1 then 
   122             raise RecError "position of rec. argument inconsistent"
   123           else if i = 0 then 
   124             raise RecError "illegal argument in pattern" 
   125           else if i mem cis then
   126             raise RecError "constructor already occured as pattern "
   127                else (i,abst_rec (name,rpos,the tc,ls,cargs,rs,rhs))
   128                      :: trans_recs' (i::cis) eqs 
   129           end
   130           handle RecError s =>
   131                 error("Primrec definition error\n" ^ s ^ "\n" 
   132                       ^ "   " ^ Sign.string_of_term (sign_of thy) eq);
   133     in (  name1, ls1
   134         , check_and_sort (length cs, trans_recs' [] (eq1::eqs)))
   135     end ;
   136 
   137 in
   138   fun add_datatype (typevars, tname, cons_list') thy = 
   139     let
   140       fun typid(dtRek(_,id)) = id
   141         | typid(dtVar s) = implode (tl (explode s))
   142         | typid(dtTyp(_,id)) = id;
   143 
   144       fun index_vnames(vn::vns,tab) =
   145             (case assoc(tab,vn) of
   146                None => if vn mem vns
   147                        then (vn^"1") :: index_vnames(vns,(vn,2)::tab)
   148                        else vn :: index_vnames(vns,tab)
   149              | Some(i) => (vn^(string_of_int i)) ::
   150                           index_vnames(vns,(vn,i+1)::tab))
   151         | index_vnames([],tab) = [];
   152 
   153       fun mk_var_names types = index_vnames(map typid types,[]);
   154 
   155       (*search for free type variables and convert recursive *)
   156       fun analyse_types (cons, types, syn) =
   157         let fun analyse(t as dtVar v) =
   158                   if t mem typevars then t
   159                   else error ("Free type variable " ^ v ^ " on rhs.")
   160               | analyse(dtTyp(typl,s)) =
   161                   if tname <> s then dtTyp(analyses typl, s)
   162                   else if typevars = typl then dtRek(typl, s)
   163                        else error (s ^ " used in different ways")
   164               | analyse(dtRek _) = raise Impossible
   165             and analyses ts = map analyse ts;
   166         in (cons, Syntax.const_name cons syn, analyses types,
   167             mk_var_names types, syn)
   168         end;
   169 
   170      (*test if all elements are recursive, i.e. if the type is empty*)
   171       
   172       fun non_empty (cs : ('a * 'b * dt_type list * 'c *'d) list) = 
   173         not(forall (exists is_dtRek o #3) cs) orelse
   174         error("Empty datatype not allowed!");
   175 
   176       val cons_list = map analyse_types cons_list';
   177       val dummy = non_empty cons_list;
   178       val num_of_cons = length cons_list;
   179 
   180      (* Auxiliary functions to construct argument and equation lists *)
   181 
   182      (*generate 'var_n, ..., var_m'*)
   183       fun Args(var, delim, n, m) = 
   184         space_implode delim (map (fn n => var^string_of_int(n)) (n upto m));
   185 
   186       fun C_exp name vns = name ^ opt_parens(space_implode ") (" vns);
   187 
   188      (*Arg_eqs([x1,...,xn],[y1,...,yn]) = "x1 = y1 & ... & xn = yn" *)
   189       fun arg_eqs vns vns' =
   190         let fun mkeq(x,x') = x ^ "=" ^ x'
   191         in space_implode " & " (map mkeq (vns~~vns')) end;
   192 
   193      (*Pretty printers for type lists;
   194        pp_typlist1: parentheses, pp_typlist2: brackets*)
   195       fun pp_typ (dtVar s) = "(" ^ s ^ "::term)"
   196         | pp_typ (dtTyp (typvars, id)) =
   197           if null typvars then id else (pp_typlist1 typvars) ^ id
   198         | pp_typ (dtRek (typvars, id)) = (pp_typlist1 typvars) ^ id
   199       and
   200         pp_typlist' ts = commas (map pp_typ ts)
   201       and
   202         pp_typlist1 ts = if null ts then "" else parens (pp_typlist' ts);
   203 
   204       fun pp_typlist2 ts = if null ts then "" else brackets (pp_typlist' ts);
   205 
   206      (* Generate syntax translation for case rules *)
   207       fun calc_xrules c_nr y_nr ((_, name, _, vns, _) :: cs) = 
   208         let val arity = length vns;
   209           val body  = "z" ^ string_of_int(c_nr);
   210           val args1 = if arity=0 then ""
   211                       else " " ^ Args ("y", " ", y_nr, y_nr+arity-1);
   212           val args2 = if arity=0 then ""
   213                       else "(% " ^ Args ("y", " ", y_nr, y_nr+arity-1) 
   214                         ^ ". ";
   215           val (rest1,rest2) = 
   216             if null cs then ("","")
   217             else let val (h1, h2) = calc_xrules (c_nr+1) (y_nr+arity) cs
   218             in (" | " ^ h1, " " ^ h2) end;
   219         in (name ^ args1 ^ " => " ^ body ^ rest1,
   220             args2 ^ body ^ (if args2 = "" then "" else ")") ^ rest2)
   221         end
   222         | calc_xrules _ _ [] = raise Impossible;
   223       
   224       val xrules =
   225         let val (first_part, scnd_part) = calc_xrules 1 1 cons_list
   226         in [("logic", "case x of " ^ first_part) <->
   227              ("logic", tname ^ "_case " ^ scnd_part ^ " x")]
   228         end;
   229 
   230      (*type declarations for constructors*)
   231       fun const_type (id, _, typlist, _, syn) =
   232         (id,  
   233          (if null typlist then "" else pp_typlist2 typlist ^ " => ") ^
   234             pp_typlist1 typevars ^ tname, syn);
   235 
   236 
   237       fun assumpt (dtRek _ :: ts, v :: vs ,found) =
   238         let val h = if found then ";P(" ^ v ^ ")" else "[| P(" ^ v ^ ")"
   239         in h ^ (assumpt (ts, vs, true)) end
   240         | assumpt (t :: ts, v :: vs, found) = assumpt (ts, vs, found)
   241       | assumpt ([], [], found) = if found then "|] ==>" else ""
   242         | assumpt _ = raise Impossible;
   243 
   244       fun t_inducting ((_, name, types, vns, _) :: cs) =
   245         let
   246           val h = if null types then " P(" ^ name ^ ")"
   247                   else " !!" ^ (space_implode " " vns) ^ "." ^
   248                     (assumpt (types, vns, false)) ^
   249                     "P(" ^ C_exp name vns ^ ")";
   250           val rest = t_inducting cs;
   251         in if rest = "" then h else h ^ "; " ^ rest end
   252         | t_inducting [] = "";
   253 
   254       fun t_induct cl typ_name =
   255         "[|" ^ t_inducting cl ^ "|] ==> P(" ^ typ_name ^ ")";
   256 
   257       fun gen_typlist typevar f ((_, _, ts, _, _) :: cs) =
   258         let val h = if (length ts) > 0
   259                       then pp_typlist2(f ts) ^ "=>"
   260                     else ""
   261         in h ^ typevar ^  "," ^ (gen_typlist typevar f cs) end
   262         | gen_typlist _ _ [] = "";
   263 
   264 
   265 (* -------------------------------------------------------------------- *)
   266 (* The case constant and rules                                          *)
   267                 
   268       val t_case = tname ^ "_case";
   269 
   270       fun case_rule n (id, name, _, vns, _) =
   271         let val args = if vns = [] then "" else " " ^ space_implode " " vns
   272         in (t_case ^ "_" ^ id,
   273             t_case ^ " " ^ Args("f", " ", 1, num_of_cons)
   274             ^ " (" ^ name ^ args ^ ") = f"^string_of_int(n) ^ args)
   275         end
   276 
   277       fun case_rules n (c :: cs) = case_rule n c :: case_rules(n+1) cs
   278         | case_rules _ [] = [];
   279 
   280       val datatype_arity = length typevars;
   281 
   282       val types = [(tname, datatype_arity, NoSyn)];
   283 
   284       val arities = 
   285         let val term_list = replicate datatype_arity termS;
   286         in [(tname, term_list, termS)] 
   287         end;
   288 
   289       val datatype_name = pp_typlist1 typevars ^ tname;
   290 
   291       val new_tvar_name = variant (map (fn dtVar s => s) typevars) "'z";
   292 
   293       val case_const =
   294         (t_case,
   295          "[" ^ gen_typlist new_tvar_name I cons_list 
   296          ^  pp_typlist1 typevars ^ tname ^ "] =>" ^ new_tvar_name^"::term",
   297          NoSyn);
   298 
   299       val rules_case = case_rules 1 cons_list;
   300 
   301 (* -------------------------------------------------------------------- *)
   302 (* The prim-rec combinator                                              *) 
   303 
   304       val t_rec = tname ^ "_rec"
   305 
   306 (* adding type variables for dtRek types to end of list of dt_types      *)   
   307 
   308       fun add_reks ts = 
   309         ts @ map (fn _ => dtVar new_tvar_name) (filter is_dtRek ts); 
   310 
   311 (* positions of the dtRek types in a list of dt_types, starting from 1  *)
   312       fun rek_vars ts vns = map snd (filter (is_dtRek o fst) (ts ~~ vns))
   313 
   314       fun rec_rule n (id,name,ts,vns,_) = 
   315         let val args = opt_parens(space_implode ") (" vns)
   316           val fargs = opt_parens(Args("f", ") (", 1, num_of_cons))
   317           fun rarg vn = t_rec ^ fargs ^ " (" ^ vn ^ ")"
   318           val rargs = opt_parens(space_implode ") ("
   319                                  (map rarg (rek_vars ts vns)))
   320         in
   321           (t_rec ^ "_" ^ id,
   322            t_rec ^ fargs ^ " (" ^ name ^ args ^ ") = f"
   323            ^ string_of_int(n) ^ args ^ rargs)
   324         end
   325 
   326       fun rec_rules n (c::cs) = rec_rule n c :: rec_rules (n+1) cs 
   327         | rec_rules _ [] = [];
   328 
   329       val rec_const =
   330         (t_rec,
   331          "[" ^ (gen_typlist new_tvar_name add_reks cons_list) 
   332          ^ (pp_typlist1 typevars) ^ tname ^ "] =>" ^ new_tvar_name^"::term",
   333          NoSyn);
   334 
   335       val rules_rec = rec_rules 1 cons_list
   336 
   337 (* -------------------------------------------------------------------- *)
   338       val consts = 
   339         map const_type cons_list
   340         @ (if num_of_cons < dtK then []
   341            else [(tname ^ "_ord", datatype_name ^ "=>nat", NoSyn)])
   342         @ [case_const,rec_const];
   343 
   344 
   345       fun Ci_ing ((id, name, _, vns, _) :: cs) =
   346            if null vns then Ci_ing cs
   347            else let val vns' = variantlist(vns,vns)
   348                 in ("inject_" ^ id,
   349                     "(" ^ (C_exp name vns) ^ "=" ^ (C_exp name vns')
   350                     ^ ") = (" ^ (arg_eqs vns vns') ^ ")") :: (Ci_ing cs)
   351                 end
   352         | Ci_ing [] = [];
   353 
   354       fun Ci_negOne (id1,name1,_,vns1,_) (id2,name2,_,vns2,_) =
   355             let val vns2' = variantlist(vns2,vns1)
   356                 val ax = C_exp name1 vns1 ^ "~=" ^ C_exp name2 vns2'
   357         in (id1 ^ "_not_" ^ id2, ax) end;
   358 
   359       fun Ci_neg1 [] = []
   360         | Ci_neg1 (c1::cs) = (map (Ci_negOne c1) cs) @ Ci_neg1 cs;
   361 
   362       fun suc_expr n = 
   363         if n=0 then "0" else "Suc(" ^ suc_expr(n-1) ^ ")";
   364 
   365       fun Ci_neg2() =
   366         let val ord_t = tname ^ "_ord";
   367           val cis = cons_list ~~ (0 upto (num_of_cons - 1))
   368           fun Ci_neg2equals ((id, name, _, vns, _), n) =
   369             let val ax = ord_t ^ "(" ^ (C_exp name vns) ^ ") = " ^ (suc_expr n)
   370             in (ord_t ^ "_" ^ id, ax) end
   371         in (ord_t ^ "_distinct", ord_t^"(x) ~= "^ord_t^"(y) ==> x ~= y") ::
   372           (map Ci_neg2equals cis)
   373         end;
   374 
   375       val rules_distinct = if num_of_cons < dtK then Ci_neg1 cons_list
   376                            else Ci_neg2();
   377 
   378       val rules_inject = Ci_ing cons_list;
   379 
   380       val rule_induct = (tname ^ "_induct", t_induct cons_list tname);
   381 
   382       val rules = rule_induct ::
   383         (rules_inject @ rules_distinct @ rules_case @ rules_rec);
   384 
   385       fun add_primrec eqns thy =
   386         let val rec_comb = Const(t_rec,dummyT)
   387           val teqns = map (fn neq => snd(read_axm (sign_of thy) neq)) eqns
   388           val (fname,ls,fns) = trans_recs thy cons_list teqns
   389           val rhs = 
   390             list_abs_free
   391             (ls @ [(tname,dummyT)]
   392              ,list_comb(rec_comb
   393                         , fns @ map Bound (0 ::(length ls downto 1))));
   394           val sg = sign_of thy;
   395           val defpair = (fname ^ "_" ^ tname ^ "_def",
   396                          Logic.mk_equals (Const(fname,dummyT), rhs))
   397           val defpairT as (_, _ $ Const(_,T) $ _ ) = inferT_axm sg defpair;
   398           val varT = Type.varifyT T;
   399           val ftyp = the (Sign.const_type sg fname);
   400         in add_defs_i [defpairT] thy end;
   401 
   402     in
   403       datatypes := map (fn (x,_,_) => x) cons_list' @ (!datatypes);
   404       (thy |> add_types types
   405            |> add_arities arities
   406            |> add_consts consts
   407            |> add_trrules xrules
   408            |> add_axioms rules, add_primrec)
   409     end
   410 end
   411 end
   412 
   413 (*
   414 Informal description of functions used in datatype.ML for the Isabelle/HOL
   415 implementation of prim. rec. function definitions. (N. Voelker, Feb. 1995) 
   416 
   417 * subst_apps (fname,rpos) pairs t:
   418    substitute the term 
   419        fname(ls,xk,rs) 
   420    by 
   421       yk(ls,rs) 
   422    in t for (xk,yk) in pairs, where rpos = length ls. 
   423    Applied with : 
   424      fname = function name 
   425      rpos = position of recursive argument 
   426      pairs = list of pairs (xk,yk), where 
   427           xk are the rec. arguments of the constructor in the pattern,
   428           yk is a variable with name derived from xk 
   429      t = rhs of equation 
   430 
   431 * abst_rec (fname,rpos,tc,ls,cargs,rs,rhs)
   432   - filter recursive arguments from constructor arguments cargs,
   433   - perform substitutions on rhs, 
   434   - derive list subs of new variable names yk for use in subst_apps, 
   435   - abstract rhs with respect to cargs, subs, ls and rs. 
   436 
   437 * dest_eq t 
   438   destruct a term denoting an equation into lhs and rhs. 
   439 
   440 * dest_rec eq 
   441   destruct an equation of the form 
   442       name (vl1..vlrpos, Ci(vi1..vin), vr1..vrn) = rhs
   443   into 
   444   - function name  (name) 
   445   - position of the first non-variable parameter  (rpos)
   446   - the list of first rpos parameters (ls = [vl1..vlrpos]) 
   447   - the constructor (fst( dest_Const c) = Ci)
   448   - the arguments of the constructor (cargs = [vi1..vin])
   449   - the rest of the variables in the pattern (rs = [vr1..vrn])
   450   - the right hand side of the equation (rhs).  
   451  
   452 * check_and_sort (n,its)
   453   check that  n = length its holds, and sort elements of its by 
   454   first component. 
   455 
   456 * trans_recs thy cs' (eq1::eqs)
   457   destruct eq1 into name1, rpos1, ls1, etc.. 
   458   get constructor list with and without type (tcs resp. cs) from cs',  
   459   for every equation:  
   460     destruct it into (name,rpos,ls,c,cargs,rs,rhs)
   461     get typed constructor tc from c and tcs 
   462     determine the index i of the constructor 
   463     check function name and position of rec. argument by comparison
   464     with first equation 
   465     check for repeated variable names in pattern
   466     derive function term f_i which is used as argument of the rec. combinator
   467     sort the terms f_i according to i and return them together
   468       with the function name and the parameter of the definition (ls). 
   469 
   470 * Application:
   471 
   472   The rec. combinator is applied to the function terms resulting from
   473   trans_recs. This results in a function which takes the recursive arg. 
   474   as first parameter and then the arguments corresponding to ls. The
   475   order of parameters is corrected by setting the rhs equal to 
   476 
   477   list_abs_free
   478             (ls @ [(tname,dummyT)]
   479              ,list_comb(rec_comb
   480                         , fns @ map Bound (0 ::(length ls downto 1))));
   481 
   482   Note the de-Bruijn indices counting the number of lambdas between the
   483   variable and its binding. 
   484 *)