src/HOL/Tools/primrec_package.ML
author wenzelm
Sat Jul 01 19:55:22 2000 +0200 (2000-07-01)
changeset 9230 17ae63f82ad8
parent 8973 ac448cd43452
child 9315 f793f05024f6
permissions -rw-r--r--
GPLed;
     1 (*  Title:      HOL/Tools/primrec_package.ML
     2     ID:         $Id$
     3     Author:     Stefan Berghofer and Norbert Voelker
     4     Copyright   1998  TU Muenchen
     5 
     6 Package for defining functions on datatypes by primitive recursion.
     7 *)
     8 
(*Interface of the primrec package: definition of functions on datatypes
  by primitive recursion, plus access to the recorded definitions.*)
signature PRIMREC_PACKAGE =
sig
  (*when set, progress messages of the package are suppressed*)
  val quiet_mode: bool ref
  (*print the names of all functions recorded by the package in a theory*)
  val print_primrecs: theory -> unit
  (*get_primrec thy name: the list of (datatype name, simplification rules)
    entries recorded for function constant name, if any*)
  val get_primrec: theory -> string -> (string * thm list) list option
  (*add_primrec alt_name eqns: define functions from named equations given
    as strings, with attributes given as sources; alt_name = "" means that
    the theory path name is derived from the defined functions*)
  val add_primrec: string -> ((bstring * string) * Args.src list) list
    -> theory -> theory * thm list
  (*like add_primrec, but with equations as terms and attributes already
    evaluated*)
  val add_primrec_i: string -> ((bstring * term) * theory attribute list) list
    -> theory -> theory * thm list
  (*theory setup functions of the package*)
  val setup: (theory -> theory) list
end;
    20 
    21 structure PrimrecPackage : PRIMREC_PACKAGE =
    22 struct
    23 
    24 open DatatypeAux;
    25 
    26 exception RecError of string;
    27 
    28 fun primrec_err s = error ("Primrec definition error:\n" ^ s);
    29 fun primrec_eq_err sign s eq =
    30   primrec_err (s ^ "\nin\n" ^ quote (Sign.string_of_term sign eq));
    31 
    32 
    33 (* messages *)
    34 
    35 val quiet_mode = ref false;
    36 fun message s = if ! quiet_mode then () else writeln s;
    37 
    38 
    39 (** theory data **)
    40 
    41 (* data kind 'HOL/primrec' *)
    42 
    43 structure PrimrecArgs =
    44 struct
    45   val name = "HOL/primrec";
    46   type T = (string * thm list) list Symtab.table;
    47 
    48   val empty = Symtab.empty;
    49   val copy = I;
    50   val prep_ext = I;
    51   val merge: T * T -> T = Symtab.merge (K true);
    52 
    53   fun print sg tab =
    54     Pretty.writeln (Pretty.strs ("primrecs:" ::
    55       map #1 (Sign.cond_extern_table sg Sign.constK tab)));
    56 end;
    57 
    58 structure PrimrecData = TheoryDataFun(PrimrecArgs);
    59 val print_primrecs = PrimrecData.print;
    60 
    61 
    62 (* get and put data *)
    63 
    64 fun get_primrec thy name = Symtab.lookup (PrimrecData.get thy, name);
    65 
    66 fun put_primrec name info thy =
    67   let val tab = PrimrecData.get thy
    68   in 
    69     PrimrecData.put (case Symtab.lookup (tab, name) of
    70        None => Symtab.update_new ((name, [info]), tab)
    71      | Some infos => Symtab.update ((name, info::infos), tab)) thy end;
    72 
    73 
    74 (* preprocessing of equations *)
    75 
    76 fun process_eqn sign (eq, rec_fns) = 
    77   let
    78     val (lhs, rhs) = 
    79 	if null (term_vars eq) then
    80 	    HOLogic.dest_eq (HOLogic.dest_Trueprop eq)
    81 	      handle TERM _ => raise RecError "not a proper equation"
    82 	else raise RecError "illegal schematic variable(s)";
    83 
    84     val (recfun, args) = strip_comb lhs;
    85     val (fname, _) = dest_Const recfun handle TERM _ => 
    86       raise RecError "function is not declared as constant in theory";
    87 
    88     val (ls', rest)  = take_prefix is_Free args;
    89     val (middle, rs') = take_suffix is_Free rest;
    90     val rpos = length ls';
    91 
    92     val (constr, cargs') = if null middle then raise RecError "constructor missing"
    93       else strip_comb (hd middle);
    94     val (cname, T) = dest_Const constr
    95       handle TERM _ => raise RecError "ill-formed constructor";
    96     val (tname, _) = dest_Type (body_type T) handle TYPE _ =>
    97       raise RecError "cannot determine datatype associated with function"
    98 
    99     val (ls, cargs, rs) = (map dest_Free ls', 
   100 			   map dest_Free cargs', 
   101 			   map dest_Free rs')
   102       handle TERM _ => raise RecError "illegal argument in pattern";
   103     val lfrees = ls @ rs @ cargs;
   104 
   105     val _ = case duplicates lfrees of
   106 	       [] => ()
   107 	     | vars => raise RecError("repeated variable names in pattern: " ^ 
   108 				      commas_quote (map #1 vars))
   109  
   110     val _ =  case (map dest_Free (term_frees rhs)) \\ lfrees of
   111 		[] => ()
   112 	      | vars => raise RecError 
   113 		      ("extra variables on rhs: " ^ 
   114 		       commas_quote (map #1 vars))
   115   in
   116     if length middle > 1 then 
   117       raise RecError "more than one non-variable in pattern"
   118     else (case assoc (rec_fns, fname) of
   119         None =>
   120           (fname, (tname, rpos, [(cname, (ls, cargs, rs, rhs, eq))]))::rec_fns
   121       | Some (_, rpos', eqns) =>
   122           if is_some (assoc (eqns, cname)) then
   123             raise RecError "constructor already occurred as pattern"
   124           else if rpos <> rpos' then
   125             raise RecError "position of recursive argument inconsistent"
   126           else
   127             overwrite (rec_fns, 
   128 		       (fname, 
   129 			(tname, rpos,
   130 			 (cname, (ls, cargs, rs, rhs, eq))::eqns))))
   131   end
   132   handle RecError s => primrec_eq_err sign s eq;
   133 
(*process_fun sign descr rec_eqns ((i, fname), (fnames, fnss)):
  translate the equations of function fname, which recurses on the datatype
  with index i in descr, into the list of function arguments for the
  recursion combinator of that datatype (one per constructor).
  - rec_eqns: list built by process_eqn
      (function name, (datatype name, recursive argument position,
                       equations indexed by constructor name))
  - fnames: (datatype index, function name) pairs processed so far
  - fnss: (datatype index, (function name, ls of the head entry of its
    equation list, combinator arguments)) entries built so far
  Recursive calls on a right-hand side cause the called function to be
  processed as well, for the datatype index of the recursive argument.
  Raises RecError if a datatype index is reached with a different function
  than the one already chosen for it, or if fname was already chosen for a
  different index.*)
fun process_fun sign descr rec_eqns ((i, fname), (fnames, fnss)) =
  let
    val (_, (tname, _, constrs)) = nth_elem (i, descr);

    (*subst subs (fs, t): replace each call "fname' ls (x xs) rs" in t, where
      fname' is a function of rec_eqns and (x, (i', y)) is an entry of subs,
      by "y xs ls rs" (after substituting in the arguments), and process
      fname' for datatype index i', threading the accumulator
      fs = (fnames, fnss).  Calls whose recursive argument is not a key of
      subs are kept.  With subs = [] the input is returned unchanged.
      Raises RecError if a call has no argument in the recursive position.*)

    fun subst [] x = x
      | subst subs (fs, Abs (a, T, t)) =
          let val (fs', t') = subst subs (fs, t)
          in (fs', Abs (a, T, t')) end
      | subst subs (fs, t as (_ $ _)) =
          let val (f, ts) = strip_comb t;
          in
            if is_Const f andalso (fst (dest_Const f)) mem (map fst rec_eqns) then
              let
                val (fname', _) = dest_Const f;
                val (_, rpos, _) = the (assoc (rec_eqns, fname'));
                val ls = take (rpos, ts);
                val rest = drop (rpos, ts);
                val (x', rs) = (hd rest, tl rest)
                  handle LIST _ => raise RecError ("not enough arguments\
                   \ in recursive application\nof function " ^ quote fname' ^ " on rhs");
                val (x, xs) = strip_comb x'
              in 
                (case assoc (subs, x) of
                    None =>
                      let
                        val (fs', ts') = foldl_map (subst subs) (fs, ts)
                      in (fs', list_comb (f, ts')) end
                  | Some (i', y) =>
                      let
                        val (fs', ts') = foldl_map (subst subs) (fs, xs @ ls @ rs);
                        (*the called function recurses on datatype i'*)
                        val fs'' = process_fun sign descr rec_eqns ((i', fname'), fs')
                      in (fs'', list_comb (y, ts'))
                      end)
              end
            else
              let
                val (fs', f'::ts') = foldl_map (subst subs) (fs, f::ts)
              in (fs', list_comb (f', ts')) end
          end
      | subst _ x = x;

    (*trans eqns ((cname, cargs), (fnames', fnss', fns)): prepend to fns the
      combinator argument for constructor cname with argument types cargs.
      Without an equation for cname a warning is issued and "arbitrary" is
      used.  Otherwise the translated rhs is abstracted over the constructor
      arguments, variables for the results of the recursive calls (named
      via rename_wrt_term w.r.t. rhs), and the extra arguments ls and rs.
      Errors in the rhs are reported against the equation eq.*)

    fun trans eqns ((cname, cargs), (fnames', fnss', fns)) =
      (case assoc (eqns, cname) of
          None => (warning ("no equation for constructor " ^ quote cname ^
            "\nin definition of function " ^ quote fname);
              (fnames', fnss', (Const ("arbitrary", dummyT))::fns))
        | Some (ls, cargs', rs, rhs, eq) =>
            let
              (*datatype index of a recursive constructor argument: either
                DtRec k itself or a function type with range DtRec k*)
              fun rec_index (DtRec k) = k
                | rec_index (DtType ("fun", [_, DtRec k])) = k;

              val recs = filter (is_rec_type o snd) (cargs' ~~ cargs);
              val rargs = map fst recs;
              val subs = map (rpair dummyT o fst) 
                             (rev (rename_wrt_term rhs rargs));
              val ((fnames'', fnss''), rhs') = 
                  (subst (map (fn ((x, y), z) =>
                               (Free x, (rec_index y, Free z)))
                          (recs ~~ subs))
                   ((fnames', fnss'), rhs))
                  handle RecError s => primrec_eq_err sign s eq
            in (fnames'', fnss'', 
                (list_abs_free (cargs' @ subs @ ls @ rs, rhs'))::fns)
            end)

  (*process fname for index i unless already done; each datatype index is
    handled by exactly one function and vice versa*)
  in (case assoc (fnames, i) of
      None =>
        if exists (equal fname o snd) fnames then
          raise RecError ("inconsistent functions for datatype " ^ quote tname)
        else
          let
            val (_, _, eqns) = the (assoc (rec_eqns, fname));
            val (fnames', fnss', fns) = foldr (trans eqns)
              (constrs, ((i, fname)::fnames, fnss, []))
          in
            (fnames', (i, (fname, #1 (snd (hd eqns)), fns))::fnss')
          end
    | Some fname' =>
        if fname = fname' then (fnames, fnss)
        else raise RecError ("inconsistent functions for datatype " ^ quote tname))
  end;
   219 
   220 
   221 (* prepare functions needed for definitions *)
   222 
   223 fun get_fns fns (((i, (tname, _, constrs)), rec_name), (fs, defs)) =
   224   case assoc (fns, i) of
   225      None =>
   226        let
   227          val dummy_fns = map (fn (_, cargs) => Const ("arbitrary",
   228            replicate ((length cargs) + (length (filter is_rec_type cargs)))
   229              dummyT ---> HOLogic.unitT)) constrs;
   230          val _ = warning ("No function definition for datatype " ^ quote tname)
   231        in
   232          (dummy_fns @ fs, defs)
   233        end
   234    | Some (fname, ls, fs') => (fs' @ fs, (fname, ls, rec_name, tname)::defs);
   235 
   236 
   237 (* make definition *)
   238 
   239 fun make_def sign fs (fname, ls, rec_name, tname) =
   240   let
   241     val rhs = foldr (fn (T, t) => Abs ("", T, t)) 
   242 	            ((map snd ls) @ [dummyT],
   243 		     list_comb (Const (rec_name, dummyT),
   244 				fs @ map Bound (0 ::(length ls downto 1))));
   245     val defpair = (Sign.base_name fname ^ "_" ^ Sign.base_name tname ^ "_def",
   246 		   Logic.mk_equals (Const (fname, dummyT), rhs))
   247   in Theory.inferT_axm sign defpair end;
   248 
   249 
   250 (* find datatypes which contain all datatypes in tnames' *)
   251 
   252 fun find_dts (dt_info : datatype_info Symtab.table) _ [] = []
   253   | find_dts dt_info tnames' (tname::tnames) =
   254       (case Symtab.lookup (dt_info, tname) of
   255           None => primrec_err (quote tname ^ " is not a datatype")
   256         | Some dt =>
   257             if tnames' subset (map (#1 o snd) (#descr dt)) then
   258               (tname, dt)::(find_dts dt_info tnames' tnames)
   259             else find_dts dt_info tnames' tnames);
   260 
   261 fun prepare_induct ({descr, induction, ...}: datatype_info) rec_eqns =
   262   let
   263     fun constrs_of (_, (_, _, cs)) =
   264       map (fn (cname:string, (_, cargs, _, _, _)) => (cname, map fst cargs)) cs;
   265     val params_of = Library.assocs (flat (map constrs_of rec_eqns));
   266   in
   267     induction
   268     |> RuleCases.rename_params (map params_of (flat (map (map #1 o #3 o #2) descr)))
   269     |> RuleCases.name (RuleCases.get induction)
   270   end;
   271 
   272 fun add_primrec_i alt_name eqns_atts thy =
   273   let
   274     val (eqns, atts) = split_list eqns_atts;
   275     val sg = Theory.sign_of thy;
   276     val dt_info = DatatypePackage.get_datatypes thy;
   277     val rec_eqns = foldr (process_eqn sg) (map snd eqns, []);
   278     val tnames = distinct (map (#1 o snd) rec_eqns);
   279     val dts = find_dts dt_info tnames tnames;
   280     val main_fns = 
   281 	map (fn (tname, {index, ...}) =>
   282 	     (index, 
   283 	      fst (the (find_first (fn f => #1 (snd f) = tname) rec_eqns))))
   284 	dts;
   285     val {descr, rec_names, rec_rewrites, ...} = 
   286 	if null dts then
   287 	    primrec_err ("datatypes " ^ commas_quote tnames ^ 
   288 			 "\nare not mutually recursive")
   289 	else snd (hd dts);
   290     val (fnames, fnss) = foldr (process_fun sg descr rec_eqns)
   291 	                       (main_fns, ([], []));
   292     val (fs, defs) = foldr (get_fns fnss) (descr ~~ rec_names, ([], []));
   293     val defs' = map (make_def sg fs) defs;
   294     val names1 = map snd fnames;
   295     val names2 = map fst rec_eqns;
   296     val primrec_name =
   297       if alt_name = "" then (space_implode "_" (map (Sign.base_name o #1) defs)) else alt_name;
   298     val (thy', defs_thms') = thy |> Theory.add_path primrec_name |>
   299       (if eq_set (names1, names2) then (PureThy.add_defs_i o map Thm.no_attributes) defs'
   300        else primrec_err ("functions " ^ commas_quote names2 ^
   301          "\nare not mutually recursive"));
   302     val rewrites = o_def :: (map mk_meta_eq rec_rewrites) @ defs_thms';
   303     val _ = message ("Proving equations for primrec function(s) " ^ commas_quote names1 ^ " ...");
   304     val simps = map (fn (_, t) => prove_goalw_cterm rewrites (cterm_of (Theory.sign_of thy') t)
   305         (fn _ => [rtac refl 1])) eqns;
   306     val (thy'', [simps']) = thy'
   307       |> PureThy.add_thmss [(("simps", simps), [Simplifier.simp_add_global])]
   308       |>> (#1 o PureThy.add_thms ((map fst eqns ~~ simps) ~~ atts))
   309       |>> (#1 o PureThy.add_thms [(("induct", prepare_induct (#2 (hd dts)) rec_eqns), [])])
   310       |>> Theory.parent_path
   311   in
   312     (foldl (fn (thy, (fname, _, _, tname)) =>
   313        put_primrec fname (tname, simps') thy) (thy'', defs), simps')
   314   end;
   315 
   316 
   317 fun add_primrec alt_name eqns thy =
   318   let
   319     val sign = Theory.sign_of thy;
   320     val ((names, strings), srcss) = apfst split_list (split_list eqns);
   321     val atts = map (map (Attrib.global_attribute thy)) srcss;
   322     val eqn_ts = map (term_of o Thm.read_cterm sign o rpair propT) strings;
   323     val rec_ts = map (fn eq => head_of (fst (HOLogic.dest_eq (HOLogic.dest_Trueprop eq)))
   324       handle TERM _ => primrec_eq_err sign "not a proper equation" eq) eqn_ts;
   325     val (_, eqn_ts') = InductivePackage.unify_consts (sign_of thy) rec_ts eqn_ts
   326   in
   327     add_primrec_i alt_name (names ~~ eqn_ts' ~~ atts) thy
   328   end;
   329 
   330 
   331 (** package setup **)
   332 
(*theory setup: initialize the primrec theory data slot*)
val setup = [PrimrecData.init];
   334 
   335 (* outer syntax *)
   336 
   337 local structure P = OuterParse and K = OuterSyntax.Keyword in
   338 
   339 val primrec_decl =
   340   Scan.optional (P.$$$ "(" |-- P.name --| P.$$$ ")") "" --
   341     Scan.repeat1 (P.opt_thm_name ":" -- P.prop --| P.marg_comment);
   342 
   343 val primrecP =
   344   OuterSyntax.command "primrec" "define primitive recursive functions on datatypes" K.thy_decl
   345     (primrec_decl >> (fn (alt_name, eqns) =>
   346       Toplevel.theory (#1 o add_primrec alt_name (map P.triple_swap eqns))));
   347 
   348 val _ = OuterSyntax.add_parsers [primrecP];
   349 
   350 end;
   351 
   352 
   353 end;