--- a/src/HOL/Tools/primrec.ML Fri Oct 01 14:15:49 2010 +0200
+++ b/src/HOL/Tools/primrec.ML Fri Oct 01 16:44:13 2010 +0200
@@ -1,8 +1,9 @@
(* Title: HOL/Tools/primrec.ML
- Author: Stefan Berghofer, TU Muenchen; Norbert Voelker, FernUni Hagen;
- Florian Haftmann, TU Muenchen
+ Author: Norbert Voelker, FernUni Hagen
+ Author: Stefan Berghofer, TU Muenchen
+ Author: Florian Haftmann, TU Muenchen
-Package for defining functions on datatypes by primitive recursion.
+Primitive recursive functions on datatypes.
*)
signature PRIMREC =
@@ -45,15 +46,19 @@
val (lhs, rhs) = HOLogic.dest_eq (HOLogic.dest_Trueprop eqn)
handle TERM _ => primrec_error "not a proper equation";
val (recfun, args) = strip_comb lhs;
- val fname = case recfun of Free (v, _) => if is_fixed v then v
+ val fname =
+ (case recfun of
+ Free (v, _) =>
+ if is_fixed v then v
else primrec_error "illegal head of function equation"
- | _ => primrec_error "illegal head of function equation";
+ | _ => primrec_error "illegal head of function equation");
val (ls', rest) = take_prefix is_Free args;
val (middle, rs') = take_suffix is_Free rest;
val rpos = length ls';
- val (constr, cargs') = if null middle then primrec_error "constructor missing"
+ val (constr, cargs') =
+ if null middle then primrec_error "constructor missing"
else strip_comb (hd middle);
val (cname, T) = dest_Const constr
handle TERM _ => primrec_error "ill-formed constructor";
@@ -73,11 +78,11 @@
else
(check_vars "repeated variable names in pattern: " (duplicates (op =) lfrees);
check_vars "extra variables on rhs: "
- (map dest_Free (OldTerm.term_frees rhs) |> subtract (op =) lfrees
+ (Term.add_frees rhs [] |> subtract (op =) lfrees
|> filter_out (is_fixed o fst));
- case AList.lookup (op =) rec_fns fname of
+ (case AList.lookup (op =) rec_fns fname of
NONE =>
- (fname, (tname, rpos, [(cname, (ls, cargs, rs, rhs, eqn))]))::rec_fns
+ (fname, (tname, rpos, [(cname, (ls, cargs, rs, rhs, eqn))])) :: rec_fns
| SOME (_, rpos', eqns) =>
if AList.defined (op =) eqns cname then
primrec_error "constructor already occurred as pattern"
@@ -85,8 +90,8 @@
primrec_error "position of recursive argument inconsistent"
else
AList.update (op =)
- (fname, (tname, rpos, (cname, (ls, cargs, rs, rhs, eqn))::eqns))
- rec_fns)
+ (fname, (tname, rpos, (cname, (ls, cargs, rs, rhs, eqn)) :: eqns))
+ rec_fns))
end handle PrimrecError (msg, NONE) => primrec_error_eqn msg spec;
fun process_fun descr eqns (i, fname) (fnames, fnss) =
@@ -110,13 +115,15 @@
val (fname', _) = dest_Free f;
val (_, rpos, _) = the (AList.lookup (op =) eqns fname');
val (ls, rs) = chop rpos ts
- val (x', rs') = case rs
- of x' :: rs => (x', rs)
- | [] => primrec_error ("not enough arguments in recursive application\n"
- ^ "of function " ^ quote fname' ^ " on rhs");
+ val (x', rs') =
+ (case rs of
+ x' :: rs => (x', rs)
+ | [] => primrec_error ("not enough arguments in recursive application\n" ^
+ "of function " ^ quote fname' ^ " on rhs"));
val (x, xs) = strip_comb x';
- in case AList.lookup (op =) subs x
- of NONE =>
+ in
+ (case AList.lookup (op =) subs x of
+ NONE =>
fs
|> fold_map (subst subs) ts
|-> (fn ts' => pair (list_comb (f, ts')))
@@ -124,12 +131,12 @@
fs
|> fold_map (subst subs) (xs @ ls @ rs')
||> process_fun descr eqns (i', fname')
- |-> (fn ts' => pair (list_comb (y, ts')))
+ |-> (fn ts' => pair (list_comb (y, ts'))))
end
else
fs
|> fold_map (subst subs) (f :: ts)
- |-> (fn (f'::ts') => pair (list_comb (f', ts')))
+ |-> (fn f' :: ts' => pair (list_comb (f', ts')))
end
| subst _ t fs = (t, fs);
@@ -137,23 +144,24 @@
fun trans eqns (cname, cargs) (fnames', fnss', fns) =
(case AList.lookup (op =) eqns cname of
- NONE => (warning ("No equation for constructor " ^ quote cname ^
- "\nin definition of function " ^ quote fname);
- (fnames', fnss', (Const ("HOL.undefined", dummyT))::fns))
- | SOME (ls, cargs', rs, rhs, eq) =>
- let
- val recs = filter (is_rec_type o snd) (cargs' ~~ cargs);
- val rargs = map fst recs;
- val subs = map (rpair dummyT o fst)
- (rev (Term.rename_wrt_term rhs rargs));
- val (rhs', (fnames'', fnss'')) = subst (map2 (fn (x, y) => fn z =>
- (Free x, (body_index y, Free z))) recs subs) rhs (fnames', fnss')
- handle PrimrecError (s, NONE) => primrec_error_eqn s eq
- in (fnames'', fnss'',
- (list_abs_free (cargs' @ subs @ ls @ rs, rhs'))::fns)
- end)
+ NONE => (warning ("No equation for constructor " ^ quote cname ^
+ "\nin definition of function " ^ quote fname);
+ (fnames', fnss', (Const (@{const_name undefined}, dummyT)) :: fns))
+ | SOME (ls, cargs', rs, rhs, eq) =>
+ let
+ val recs = filter (is_rec_type o snd) (cargs' ~~ cargs);
+ val rargs = map fst recs;
+ val subs = map (rpair dummyT o fst)
+ (rev (Term.rename_wrt_term rhs rargs));
+ val (rhs', (fnames'', fnss'')) = subst (map2 (fn (x, y) => fn z =>
+ (Free x, (body_index y, Free z))) recs subs) rhs (fnames', fnss')
+ handle PrimrecError (s, NONE) => primrec_error_eqn s eq
+ in (fnames'', fnss'',
+ (list_abs_free (cargs' @ subs @ ls @ rs, rhs')) :: fns)
+ end)
- in (case AList.lookup (op =) fnames i of
+ in
+ (case AList.lookup (op =) fnames i of
NONE =>
if exists (fn (_, v) => fname = v) fnames then
primrec_error ("inconsistent functions for datatype " ^ quote tname)
@@ -161,9 +169,9 @@
let
val (_, _, eqns) = the (AList.lookup (op =) eqns fname);
val (fnames', fnss', fns) = fold_rev (trans eqns) constrs
- ((i, fname)::fnames, fnss, [])
+ ((i, fname) :: fnames, fnss, [])
in
- (fnames', (i, (fname, #1 (snd (hd eqns)), fns))::fnss')
+ (fnames', (i, (fname, #1 (snd (hd eqns)), fns)) :: fnss')
end
| SOME fname' =>
if fname = fname' then (fnames, fnss)
@@ -174,17 +182,17 @@
(* prepare functions needed for definitions *)
fun get_fns fns ((i : int, (tname, _, constrs)), rec_name) (fs, defs) =
- case AList.lookup (op =) fns i of
- NONE =>
- let
- val dummy_fns = map (fn (_, cargs) => Const ("HOL.undefined",
- replicate (length cargs + length (filter is_rec_type cargs))
- dummyT ---> HOLogic.unitT)) constrs;
- val _ = warning ("No function definition for datatype " ^ quote tname)
- in
- (dummy_fns @ fs, defs)
- end
- | SOME (fname, ls, fs') => (fs' @ fs, (fname, ls, rec_name, tname) :: defs);
+ (case AList.lookup (op =) fns i of
+ NONE =>
+ let
+ val dummy_fns = map (fn (_, cargs) => Const (@{const_name undefined},
+ replicate (length cargs + length (filter is_rec_type cargs))
+ dummyT ---> HOLogic.unitT)) constrs;
+ val _ = warning ("No function definition for datatype " ^ quote tname)
+ in
+ (dummy_fns @ fs, defs)
+ end
+ | SOME (fname, ls, fs') => (fs' @ fs, (fname, ls, rec_name, tname) :: defs));
(* make definition *)
@@ -203,13 +211,13 @@
(* find datatypes which contain all datatypes in tnames' *)
fun find_dts (dt_info : info Symtab.table) _ [] = []
- | find_dts dt_info tnames' (tname::tnames) =
+ | find_dts dt_info tnames' (tname :: tnames) =
(case Symtab.lookup dt_info tname of
- NONE => primrec_error (quote tname ^ " is not a datatype")
- | SOME dt =>
- if subset (op =) (tnames', map (#1 o snd) (#descr dt)) then
- (tname, dt)::(find_dts dt_info tnames' tnames)
- else find_dts dt_info tnames' tnames);
+ NONE => primrec_error (quote tname ^ " is not a datatype")
+ | SOME dt =>
+ if subset (op =) (tnames', map (#1 o snd) (#descr dt)) then
+ (tname, dt) :: (find_dts dt_info tnames' tnames)
+ else find_dts dt_info tnames' tnames);
(* distill primitive definition(s) from primrec specification *)
@@ -231,7 +239,8 @@
val defs = map (make_def lthy fixes fs) raw_defs;
val names = map snd fnames;
val names_eqns = map fst eqns;
- val _ = if eq_set (op =) (names, names_eqns) then ()
+ val _ =
+ if eq_set (op =) (names, names_eqns) then ()
else primrec_error ("functions " ^ commas_quote names_eqns ^
"\nare not mutually recursive");
val rec_rewrites' = map mk_meta_eq rec_rewrites;
@@ -246,8 +255,9 @@
in map (fn eq => Goal.prove lthy frees [] eq tac) eqs end;
in ((prefix, (fs, defs)), prove) end
handle PrimrecError (msg, some_eqn) =>
- error ("Primrec definition error:\n" ^ msg ^ (case some_eqn
- of SOME eqn => "\nin\n" ^ quote (Syntax.string_of_term lthy eqn)
+ error ("Primrec definition error:\n" ^ msg ^
+ (case some_eqn of
+ SOME eqn => "\nin\n" ^ quote (Syntax.string_of_term lthy eqn)
| NONE => ""));
@@ -306,12 +316,10 @@
(* outer syntax *)
-val primrec_decl = Parse.opt_target -- Parse.fixes -- Parse_Spec.where_alt_specs;
-
val _ =
- Outer_Syntax.command "primrec" "define primitive recursive functions on datatypes"
- Keyword.thy_decl
- (primrec_decl >> (fn ((opt_target, fixes), specs) =>
- Toplevel.local_theory opt_target (add_primrec_cmd fixes specs #> snd)));
+ Outer_Syntax.local_theory "primrec" "define primitive recursive functions on datatypes"
+ Keyword.thy_decl
+ (Parse.fixes -- Parse_Spec.where_alt_specs
+ >> (fn (fixes, specs) => add_primrec_cmd fixes specs #> snd));
end;