src/HOL/Tools/primrec_package.ML
author wenzelm
Thu Mar 11 13:20:35 1999 +0100 (1999-03-11)
changeset 6349 f7750d816c21
parent 6092 d9db67970c73
child 6359 6fdb0badc6f4
permissions -rw-r--r--
removed foo_build_completed -- now handled by session management (via usedir);
     1 (*  Title:      HOL/Tools/primrec_package.ML
     2     ID:         $Id$
     3     Author:     Stefan Berghofer and Norbert Voelker
     4     Copyright   1998  TU Muenchen
     5 
     6 Package for defining functions on datatypes by primitive recursion
     7 *)
     8 
      9 signature PRIMREC_PACKAGE = (* interface of the primrec package *)
     10 sig
     11   val add_primrec_i : string -> (string * term) list -> (* alt_name (empty: theory path built from the function names), then (theorem name, equation) pairs *)
     12     theory -> theory * thm list (* the extended theory and the proved equations *)
     13   val add_primrec : string -> (string * string) list -> (* same, but each equation is a string parsed as a proposition *)
     14     theory -> theory * thm list
     15 end;
    16 
    17 structure PrimrecPackage : PRIMREC_PACKAGE =
    18 struct
    19 
    20 open DatatypeAux;
    21 
    22 exception RecError of string;
    23 
    24 fun primrec_err s = error ("Primrec definition error:\n" ^ s);
    25 fun primrec_eq_err sign s eq =
    26   primrec_err (s ^ "\nin equation\n" ^ Sign.string_of_term sign eq);
    27 
    28 (* preprocessing of equations *)
    29 
    30 fun process_eqn sign (eq, rec_fns) = 
    31   let
    32     val (lhs, rhs) = 
    33 	if null (term_vars eq) then
    34 	    HOLogic.dest_eq (HOLogic.dest_Trueprop eq)
    35 	      handle _ => raise RecError "not a proper equation"
    36 	else raise RecError "illegal schematic variable(s)";
    37 
    38     val (recfun, args) = strip_comb lhs;
    39     val (fname, _) = dest_Const recfun handle _ => 
    40       raise RecError "function is not declared as constant in theory";
    41 
    42     val (ls', rest)  = take_prefix is_Free args;
    43     val (middle, rs') = take_suffix is_Free rest;
    44     val rpos = length ls';
    45 
    46     val (constr, cargs') = if null middle then raise RecError "constructor missing"
    47       else strip_comb (hd middle);
    48     val (cname, T) = dest_Const constr
    49       handle _ => raise RecError "ill-formed constructor";
    50     val (tname, _) = dest_Type (body_type T) handle _ =>
    51       raise RecError "cannot determine datatype associated with function"
    52 
    53     val (ls, cargs, rs) = (map dest_Free ls', 
    54 			   map dest_Free cargs', 
    55 			   map dest_Free rs')
    56       handle _ => raise RecError "illegal argument in pattern";
    57     val lfrees = ls @ rs @ cargs;
    58 
    59   in
    60     if not (null (duplicates lfrees)) then 
    61       raise RecError "repeated variable name in pattern" 
    62     else if not ((map dest_Free (term_frees rhs)) subset lfrees) then
    63       raise RecError "extra variables on rhs"
    64     else if length middle > 1 then 
    65       raise RecError "more than one non-variable in pattern"
    66     else (case assoc (rec_fns, fname) of
    67         None =>
    68           (fname, (tname, rpos, [(cname, (ls, cargs, rs, rhs, eq))]))::rec_fns
    69       | Some (_, rpos', eqns) =>
    70           if is_some (assoc (eqns, cname)) then
    71             raise RecError "constructor already occurred as pattern"
    72           else if rpos <> rpos' then
    73             raise RecError "position of recursive argument inconsistent"
    74           else
    75             overwrite (rec_fns, 
    76 		       (fname, 
    77 			(tname, rpos,
    78 			 (cname, (ls, cargs, rs, rhs, eq))::eqns))))
    79   end
    80   handle RecError s => primrec_eq_err sign s eq;
    81 
    82 fun process_fun sign descr rec_eqns ((i, fname), (fnames, fnss)) =
    83   let
    84     val (_, (tname, _, constrs)) = nth_elem (i, descr);
    85 
    86     (* substitute "fname ls x rs" by "y ls rs" for (x, (_, y)) in subs *)
    87 
    88     fun subst [] x = x
    89       | subst subs (fs, Abs (a, T, t)) =
    90           let val (fs', t') = subst subs (fs, t)
    91           in (fs', Abs (a, T, t')) end
    92       | subst subs (fs, t as (_ $ _)) =
    93           let val (f, ts) = strip_comb t;
    94           in
    95             if is_Const f andalso (fst (dest_Const f)) mem (map fst rec_eqns) then
    96               let
    97                 val (fname', _) = dest_Const f;
    98                 val (_, rpos, _) = the (assoc (rec_eqns, fname'));
    99                 val ls = take (rpos, ts);
   100                 val rest = drop (rpos, ts);
   101                 val (x, rs) = (hd rest, tl rest)
   102                   handle _ => raise RecError ("not enough arguments\
   103                    \ in recursive application\nof function " ^ fname' ^ " on rhs")
   104               in 
   105                 (case assoc (subs, x) of
   106                     None =>
   107                       let
   108                         val (fs', ts') = foldl_map (subst subs) (fs, ts)
   109                       in (fs', list_comb (f, ts')) end
   110                   | Some (i', y) =>
   111                       let
   112                         val (fs', ts') = foldl_map (subst subs) (fs, ls @ rs);
   113                         val fs'' = process_fun sign descr rec_eqns ((i', fname'), fs')
   114                       in (fs'', list_comb (y, ts'))
   115                       end)
   116               end
   117             else
   118               let
   119                 val (fs', f'::ts') = foldl_map (subst subs) (fs, f::ts)
   120               in (fs', list_comb (f', ts')) end
   121           end
   122       | subst _ x = x;
   123 
   124     (* translate rec equations into function arguments suitable for rec comb *)
   125 
   126     fun trans eqns ((cname, cargs), (fnames', fnss', fns)) =
   127       (case assoc (eqns, cname) of
   128           None => (warning ("no equation for constructor " ^ cname ^
   129             "\nin definition of function " ^ fname);
   130               (fnames', fnss', (Const ("arbitrary", dummyT))::fns))
   131         | Some (ls, cargs', rs, rhs, eq) =>
   132             let
   133               val recs = filter (is_rec_type o snd) (cargs' ~~ cargs);
   134               val rargs = map fst recs;
   135               val subs = map (rpair dummyT o fst) 
   136 		             (rev (rename_wrt_term rhs rargs));
   137               val ((fnames'', fnss''), rhs') = 
   138 		  (subst (map (fn ((x, y), z) =>
   139 			       (Free x, (dest_DtRec y, Free z)))
   140 			  (recs ~~ subs))
   141 		   ((fnames', fnss'), rhs))
   142                   handle RecError s => primrec_eq_err sign s eq
   143             in (fnames'', fnss'', 
   144 		(list_abs_free (cargs' @ subs @ ls @ rs, rhs'))::fns)
   145             end)
   146 
   147   in (case assoc (fnames, i) of
   148       None =>
   149         if exists (equal fname o snd) fnames then
   150           raise RecError ("inconsistent functions for datatype " ^ tname)
   151         else
   152           let
   153             val (_, _, eqns) = the (assoc (rec_eqns, fname));
   154             val (fnames', fnss', fns) = foldr (trans eqns)
   155               (constrs, ((i, fname)::fnames, fnss, []))
   156           in
   157             (fnames', (i, (fname, #1 (snd (hd eqns)), fns))::fnss')
   158           end
   159     | Some fname' =>
   160         if fname = fname' then (fnames, fnss)
   161         else raise RecError ("inconsistent functions for datatype " ^ tname))
   162   end;
   163 
   164 (* prepare functions needed for definitions *)
   165 
   166 fun get_fns fns (((i, (tname, _, constrs)), rec_name), (fs, defs)) =
   167   case assoc (fns, i) of
   168      None =>
   169        let
   170          val dummy_fns = map (fn (_, cargs) => Const ("arbitrary",
   171            replicate ((length cargs) + (length (filter is_rec_type cargs)))
   172              dummyT ---> HOLogic.unitT)) constrs;
   173          val _ = warning ("no function definition for datatype " ^ tname)
   174        in
   175          (dummy_fns @ fs, defs)
   176        end
   177    | Some (fname, ls, fs') => (fs' @ fs, (fname, ls, rec_name, tname)::defs);
   178 
   179 (* make definition *)
   180 
   181 fun make_def sign fs (fname, ls, rec_name, tname) =
   182   let
   183     val rhs = foldr (fn (T, t) => Abs ("", T, t)) 
   184 	            ((map snd ls) @ [dummyT],
   185 		     list_comb (Const (rec_name, dummyT),
   186 				fs @ map Bound (0 ::(length ls downto 1))));
   187     val defpair = (Sign.base_name fname ^ "_" ^ Sign.base_name tname ^ "_def",
   188 		   Logic.mk_equals (Const (fname, dummyT), rhs))
   189   in
   190     inferT_axm sign defpair
   191   end;
   192 
   193 (* find datatypes which contain all datatypes in tnames' *)
   194 
   195 fun find_dts (dt_info : datatype_info Symtab.table) _ [] = []
   196   | find_dts dt_info tnames' (tname::tnames) =
   197       (case Symtab.lookup (dt_info, tname) of
   198           None => primrec_err (tname ^ " is not a datatype")
   199         | Some dt =>
   200             if tnames' subset (map (#1 o snd) (#descr dt)) then
   201               (tname, dt)::(find_dts dt_info tnames' tnames)
   202             else find_dts dt_info tnames' tnames);
   203 
   204 fun add_primrec_i alt_name eqns thy =
   205   let
   206     val sg = sign_of thy;
   207     val dt_info = DatatypePackage.get_datatypes thy;
   208     val rec_eqns = foldr (process_eqn sg) (map snd eqns, []);
   209     val tnames = distinct (map (#1 o snd) rec_eqns);
   210     val dts = find_dts dt_info tnames tnames;
   211     val main_fns = 
   212 	map (fn (tname, {index, ...}) =>
   213 	     (index, 
   214 	      fst (the (find_first (fn f => #1 (snd f) = tname) rec_eqns))))
   215 	dts;
   216     val {descr, rec_names, rec_rewrites, ...} = 
   217 	if null dts then
   218 	    primrec_err ("datatypes " ^ commas tnames ^ 
   219 			 "\nare not mutually recursive")
   220 	else snd (hd dts);
   221     val (fnames, fnss) = foldr (process_fun sg descr rec_eqns)
   222 	                       (main_fns, ([], []));
   223     val (fs, defs) = foldr (get_fns fnss) (descr ~~ rec_names, ([], []));
   224     val defs' = map (make_def sg fs) defs;
   225     val names1 = map snd fnames;
   226     val names2 = map fst rec_eqns;
   227     val thy' = thy |>
   228       Theory.add_path (if alt_name = "" then (space_implode "_"
   229         (map (Sign.base_name o #1) defs)) else alt_name) |>
   230       (if eq_set (names1, names2) then Theory.add_defs_i defs'
   231        else primrec_err ("functions " ^ commas names2 ^
   232          "\nare not mutually recursive"));
   233     val rewrites = (map mk_meta_eq rec_rewrites) @ (map (get_axiom thy' o fst) defs');
   234     val _ = writeln ("Proving equations for primrec function(s)\n" ^
   235       commas names1 ^ " ...");
   236     val char_thms = map (fn (_, t) => prove_goalw_cterm rewrites (cterm_of (sign_of thy') t)
   237         (fn _ => [rtac refl 1])) eqns;
   238     val simps = char_thms;
   239     val thy'' = thy' |>
   240       PureThy.add_thmss [(("simps", simps), [Simplifier.simp_add_global])] |>
   241       PureThy.add_thms (map (rpair [])
   242         (filter_out (equal "" o fst) (map fst eqns ~~ simps))) |>
   243       Theory.parent_path;
   244   in
   245     (thy'', char_thms)
   246   end;
   247 
   248 fun add_primrec alt_name eqns thy =
   249   add_primrec_i alt_name (map (apsnd (readtm (sign_of thy) propT)) eqns) thy;
   250 
   251 end;